Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration/rust/tests/integration/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ pub mod timestamp_sorting;
pub mod tls_enforced;
pub mod tls_reload;
pub mod transaction_state;
pub mod unrecognized_aggregate;
51 changes: 51 additions & 0 deletions integration/rust/tests/integration/unrecognized_aggregate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use rust::setup::{admin_sqlx, connections_sqlx};
use sqlx::Executor;
use std::assert_matches;

async fn define_custom_aggregate_fn() {
for connection in connections_sqlx().await {
connection
.execute("DROP AGGREGATE IF EXISTS pgdog_sum (int4)")
.await
.unwrap();
connection
.execute("CREATE AGGREGATE pgdog_sum (int4) (sfunc = int4_sum, stype = bigint)")
.await
.unwrap();
connection
.execute("DROP TABLE IF EXISTS unrecognized_agg_test")
.await
.unwrap();
connection
.execute("CREATE TABLE unrecognized_agg_test (lol int4)")
.await
.unwrap();
}
admin_sqlx().await.execute("RELOAD").await.unwrap();
}

#[tokio::test]
async fn unrecognized_aggregate_function_errors_only_on_cross_shard_queries() {
define_custom_aggregate_fn().await;
let mut connections = connections_sqlx().await;
let sharded = connections.pop().unwrap();
let unsharded = connections.pop().unwrap();

let unsharded_query = unsharded
.fetch_one("SELECT pgdog_sum(lol) FROM unrecognized_agg_test")
.await;
assert_matches!(unsharded_query, Ok(_));

let sharded_query = sharded
.fetch_one("SELECT pgdog_sum(lol) FROM unrecognized_agg_test")
.await;
let err = sharded_query
.err()
.expect("unrecognized aggregate executed successfully");
assert!(err.to_string().contains("pgdog_sum is not yet supported"));

let direct_to_shard_query = sharded
.fetch_one("/* pgdog_shard: 0 */ SELECT pgdog_sum(lol) FROM unrecognized_agg_test")
.await;
assert_matches!(direct_to_shard_query, Ok(_));
}
8 changes: 7 additions & 1 deletion pgdog-stats/src/schema.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, hash::Hash, ops::Deref, sync::Arc};
use std::{
collections::{HashMap, HashSet},
hash::Hash,
ops::Deref,
sync::Arc,
};

/// Schema name -> Table name -> Relation
pub type Relations = HashMap<String, HashMap<String, Relation>>;
Expand Down Expand Up @@ -134,6 +139,7 @@ impl Relation {
pub struct SchemaInner {
pub search_path: Vec<String>,
pub relations: Relations,
pub aggregate_functions: HashSet<String>,
}

impl Hash for SchemaInner {
Expand Down
2 changes: 1 addition & 1 deletion pgdog/src/backend/pool/connection/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ mod test {
}

fn parse(stmt: &str) -> Aggregate {
Aggregate::parse(&select(stmt))
Aggregate::parse(&select(stmt), &Default::default()).unwrap()
}

#[test]
Expand Down
39 changes: 39 additions & 0 deletions pgdog/src/backend/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,18 @@ impl Schema {
.map(|p| p.trim().replace("\"", ""))
.collect();

let aggregate_functions = server
.fetch_all::<String>(
"SELECT DISTINCT proname FROM pg_proc INNER JOIN pg_aggregate ON oid = aggfnoid",
)
.await?
.into_iter()
.collect();

let inner = SchemaInner {
search_path,
relations,
aggregate_functions,
};

Ok(Self {
Expand All @@ -79,6 +88,15 @@ impl Schema {
pub(crate) fn from_parts(
search_path: Vec<String>,
relations: HashMap<(String, String), Relation>,
) -> Self {
Self::from_parts_with_agg(search_path, relations, Vec::new())
}

#[cfg(test)]
pub(crate) fn from_parts_with_agg(
search_path: Vec<String>,
relations: HashMap<(String, String), Relation>,
aggregate_functions: Vec<String>,
) -> Self {
let mut nested: StatsRelations = HashMap::new();
for ((schema, name), relation) in relations {
Expand All @@ -91,6 +109,7 @@ impl Schema {
inner: StatsSchema::new(SchemaInner {
search_path,
relations: nested,
aggregate_functions: aggregate_functions.into_iter().collect(),
}),
}
}
Expand Down Expand Up @@ -555,4 +574,24 @@ mod test {
);
}
}

#[tokio::test]
async fn test_loading_aggregate_functions() {
let mut server = test_server().await;
server.execute_checked("BEGIN").await.unwrap();

let schema = Schema::load(&mut server).await.unwrap();
assert!(!schema
.aggregate_functions
.contains(&String::from("pgdog_sum")));

server
.execute_checked("CREATE AGGREGATE pgdog_sum (int4) (sfunc = int4_sum, stype = bigint)")
.await
.unwrap();
let schema = Schema::load(&mut server).await.unwrap();
assert!(schema
.aggregate_functions
.contains(&String::from("pgdog_sum")));
}
}
36 changes: 31 additions & 5 deletions pgdog/src/frontend/router/parser/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use pg_query::protobuf::Integer;
use pg_query::protobuf::{a_const::Val, Node, SelectStmt, String as PgQueryString};
use pg_query::NodeEnum;

use crate::frontend::router::parser::{ExpressionRegistry, Function};
use super::{rewrite::statement::Error, ExpressionRegistry, Function};
use crate::backend::schema::Schema;

#[derive(Debug, Clone, PartialEq)]
pub struct AggregateTarget {
Expand Down Expand Up @@ -121,7 +122,7 @@ fn columns_match(group_by_names: &[&String], select_names: &[&String]) -> bool {

impl Aggregate {
/// Figure out what aggregates are present and which ones PgDog supports.
pub fn parse(stmt: &SelectStmt) -> Self {
pub fn parse(stmt: &SelectStmt, schema: &Schema) -> Result<Self, Error> {
let mut targets = vec![];
let mut registry = ExpressionRegistry::new();
let group_by = stmt
Expand Down Expand Up @@ -168,7 +169,13 @@ impl Aggregate {
"stddev_pop" => Some(AggregateFunction::StddevPop),
"variance" | "var_samp" => Some(AggregateFunction::VarSamp),
"var_pop" => Some(AggregateFunction::VarPop),
_ => None,
fname => {
if schema.aggregate_functions.contains(fname) {
Comment thread
sgrif marked this conversation as resolved.
return Err(Error::UnsupportedAggregate(String::from(fname)));
} else {
None
}
}
};

if let Some(function) = function {
Expand Down Expand Up @@ -196,7 +203,7 @@ impl Aggregate {
}
}

Self { targets, group_by }
Ok(Self { targets, group_by })
}

pub fn targets(&self) -> &[AggregateTarget] {
Expand Down Expand Up @@ -259,7 +266,7 @@ mod test {
}

fn parse(stmt: &str) -> Aggregate {
Aggregate::parse(&select(stmt))
Aggregate::parse(&select(stmt), &Default::default()).unwrap()
}

#[test]
Expand Down Expand Up @@ -455,4 +462,23 @@ mod test {
assert_eq!(aggr.group_by(), &[0]);
assert_eq!(aggr.targets().len(), 1);
}

#[test]
fn test_unrecognized_aggregate_function_errors() {
let schema_with_agg = Schema::from_parts_with_agg(
Vec::new(),
Default::default(),
vec![String::from("mysum")],
);
let schema_without_agg = Default::default();
let query = select("SELECT mysum(lol) FROM example");

// A random function that isn't listed as aggregate in the schema
// doesn't require special support on our end, so we should be fine.
assert!(Aggregate::parse(&query, &schema_without_agg).is_ok());
// If we see an aggregate function we don't recognize, we can't
// process the query correctly, since we need to combine the
// results from each shard.
assert!(Aggregate::parse(&query, &schema_with_agg).is_err());
}
}
1 change: 1 addition & 0 deletions pgdog/src/frontend/router/parser/cache/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ impl Ast {
db_schema,
user,
search_path,
is_direct_to_shard: query.is_direct_to_shard,
})
.maybe_rewrite()?;

Expand Down
41 changes: 28 additions & 13 deletions pgdog/src/frontend/router/parser/cache/cache_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use parking_lot::{Mutex, RawMutex};
use std::sync::Arc;
use tracing::debug;

use super::super::{Error, Route};
use super::{super::parse_edge_comment, Ast, AstContext, AstQuery};
use super::super::{Error, Route, ShardOptionExt};
use super::{super::parse_edge_comment, key_pair::KeyPair, Ast, AstContext, AstQuery};
use crate::frontend::{BufferedQuery, PreparedStatements};

static CACHE: Lazy<Cache> = Lazy::new(Cache::new);
Expand Down Expand Up @@ -47,7 +47,7 @@ impl Stats {
#[derive(Debug)]
pub(super) struct Inner {
/// Least-recently-used cache.
queries: LruCache<Arc<str>, Ast>,
queries: LruCache<(Arc<str>, bool), Ast>,
/// Cache global stats.
pub(super) stats: Stats,
}
Expand Down Expand Up @@ -119,10 +119,16 @@ impl Cache {

{
let mut guard = self.inner.lock();
let ast = guard.queries.get_mut(query_and_comment.query).map(|entry| {
entry.stats.lock().hits += 1; // No contention on this.
entry.clone()
});
let ast = guard
.queries
.get_mut(
&(query_and_comment.query, query_and_comment.shard.is_direct())
as &dyn KeyPair<str, bool>,
)
.map(|entry| {
entry.stats.lock().hits += 1; // No contention on this.
entry.clone()
});
if let Some(mut ast) = ast {
guard.stats.hits += 1;
ast.comment_role = query_and_comment.role;
Expand All @@ -137,6 +143,7 @@ impl Cache {
&AstQuery {
original_query: query,
query_without_comment: query_and_comment.query,
is_direct_to_shard: query_and_comment.shard.is_direct(),
},
ctx,
prepared_statements,
Expand All @@ -153,9 +160,13 @@ impl Cache {
// (direct-shard) variant.
let cacheable = entry.comment_shard.is_none() || entry.rewrite_plan.is_empty();
if cacheable {
guard
.queries
.put(entry.query_without_comment.clone(), entry.clone());
guard.queries.put(
(
entry.query_without_comment.clone(),
entry.comment_shard.is_direct(),
),
entry.clone(),
);
}
guard.stats.misses += 1;
guard.stats.parse_time += parse_time;
Expand All @@ -177,6 +188,7 @@ impl Cache {
&AstQuery {
original_query: query,
query_without_comment: query_and_comment.query,
is_direct_to_shard: query_and_comment.shard.is_direct(),
},
ctx,
prepared_statements,
Expand Down Expand Up @@ -208,7 +220,10 @@ impl Cache {

{
let mut guard = self.inner.lock();
if let Some(entry) = guard.queries.get(normalized.as_str()) {
if let Some(entry) = guard
.queries
.get(&(normalized.as_str(), false) as &dyn KeyPair<str, bool>)
{
entry.update_stats(route);
guard.stats.hits += 1;
return Ok(());
Expand All @@ -219,7 +234,7 @@ impl Cache {
entry.update_stats(route);

let mut guard = self.inner.lock();
guard.queries.put(normalized.into(), entry);
guard.queries.put((normalized.into(), false), entry);
guard.stats.misses += 1;

Ok(())
Expand Down Expand Up @@ -259,7 +274,7 @@ impl Cache {
.lock()
.queries
.iter()
.map(|i| (i.0.clone(), i.1.clone()))
.map(|i| (i.0 .0.clone(), i.1.clone()))
.collect()
}

Expand Down
4 changes: 4 additions & 0 deletions pgdog/src/frontend/router/parser/cache/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,18 @@ pub struct AstQuery<'a> {
pub original_query: &'a BufferedQuery,
/// Query without comments and other noise.
pub query_without_comment: &'a str,
/// Is this query direct-to-shard?
pub is_direct_to_shard: bool,
}

impl<'a> AstQuery<'a> {
/// Create an AstQuery using the raw query text as the cache key.
#[cfg(test)]
pub fn from_query(query: &'a BufferedQuery) -> Self {
Self {
query_without_comment: query.query(),
original_query: query,
is_direct_to_shard: false,
}
}

Expand Down
Loading
Loading