diff --git a/pgdog/src/frontend/router/parser/cache/cache_impl.rs b/pgdog/src/frontend/router/parser/cache/cache_impl.rs index 5d66d3430..b49580e56 100644 --- a/pgdog/src/frontend/router/parser/cache/cache_impl.rs +++ b/pgdog/src/frontend/router/parser/cache/cache_impl.rs @@ -10,8 +10,8 @@ use parking_lot::{Mutex, RawMutex}; use std::sync::Arc; use tracing::debug; -use super::super::{Error, Route}; -use super::{super::parse_edge_comment, Ast, AstContext, AstQuery}; +use super::super::{Error, Route, ShardOptionExt}; +use super::{super::parse_edge_comment, key_pair::KeyPair, Ast, AstContext, AstQuery}; use crate::frontend::{BufferedQuery, PreparedStatements}; static CACHE: Lazy = Lazy::new(Cache::new); @@ -47,7 +47,7 @@ impl Stats { #[derive(Debug)] pub(super) struct Inner { /// Least-recently-used cache. - queries: LruCache, Ast>, + queries: LruCache<(Arc, bool), Ast>, /// Cache global stats. pub(super) stats: Stats, } @@ -119,10 +119,16 @@ impl Cache { { let mut guard = self.inner.lock(); - let ast = guard.queries.get_mut(query_and_comment.query).map(|entry| { - entry.stats.lock().hits += 1; // No contention on this. - entry.clone() - }); + let ast = guard + .queries + .get_mut( + &(query_and_comment.query, query_and_comment.shard.is_direct()) + as &dyn KeyPair, + ) + .map(|entry| { + entry.stats.lock().hits += 1; // No contention on this. + entry.clone() + }); if let Some(mut ast) = ast { guard.stats.hits += 1; ast.comment_role = query_and_comment.role; @@ -153,9 +159,13 @@ impl Cache { // (direct-shard) variant. let cacheable = entry.comment_shard.is_none() || entry.rewrite_plan.is_empty(); if cacheable { - guard - .queries - .put(entry.query_without_comment.clone(), entry.clone()); + guard.queries.put( + ( + entry.query_without_comment.clone(), + entry.comment_shard.is_direct(), + ), + entry.clone(), + ); } guard.stats.misses += 1; guard.stats.parse_time += parse_time; @@ -208,7 +218,10 @@ impl Cache { { let mut guard = self.inner.lock(); - if let Some(entry) = guard.queries.get(normalized.as_str()) { + if let Some(entry) = guard + .queries + .get(&(normalized.as_str(), false) as &dyn KeyPair) + { entry.update_stats(route); guard.stats.hits += 1; return Ok(()); @@ -219,7 +232,7 @@ impl Cache { entry.update_stats(route); let mut guard = self.inner.lock(); - guard.queries.put(normalized.into(), entry); + guard.queries.put((normalized.into(), false), entry); guard.stats.misses += 1; Ok(()) @@ -259,7 +272,7 @@ impl Cache { .lock() .queries .iter() - .map(|i| (i.0.clone(), i.1.clone())) + .map(|i| (i.0 .0.clone(), i.1.clone())) .collect() } diff --git a/pgdog/src/frontend/router/parser/cache/key_pair.rs b/pgdog/src/frontend/router/parser/cache/key_pair.rs new file mode 100644 index 000000000..6278375f9 --- /dev/null +++ b/pgdog/src/frontend/router/parser/cache/key_pair.rs @@ -0,0 +1,106 @@ +//! A trait object used to index into a map-ish container using a tuple key. +//! This trait is needed when we want to use borrowed types to look up owned +//! data. There is no easy way to have a type which is semantically (String, +//! String) implement Borrow<(&str, &str)>. The signature of Borrow would +//! require us to manifest a tuple reference out of thin air. +//! +//! Instead we use a trait object which lets us erase the type entirely, +//! creating a reference to each element of the tuple when needed for comparison +//! and hashing. +use std::borrow::Borrow; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +pub(crate) trait KeyPair { + fn left(&self) -> &A; + fn right(&self) -> &B; +} + +impl KeyPair for (A, B) { + fn left(&self) -> &A { + &self.0 + } + + fn right(&self) -> &B { + &self.1 + } +} + +impl KeyPair for (&A, &B) { + fn left(&self) -> &A { + self.0 + } + + fn right(&self) -> &B { + self.1 + } +} + +impl KeyPair for (&A, B) { + fn left(&self) -> &A { + self.0 + } + + fn right(&self) -> &B { + &self.1 + } +} + +impl KeyPair for (A, &B) { + fn left(&self) -> &A { + &self.0 + } + + fn right(&self) -> &B { + self.1 + } +} + +impl KeyPair for (Arc, B) { + fn left(&self) -> &A { + &self.0 + } + + fn right(&self) -> &B { + &self.1 + } +} + +impl<'a, A, B, X, Y> Borrow + 'a> for (X, Y) +where + A: ?Sized, + B: ?Sized, + (X, Y): KeyPair + 'a, +{ + fn borrow(&self) -> &(dyn KeyPair + 'a) { + self + } +} + +impl<'a, A, B> PartialEq for dyn KeyPair + 'a +where + A: PartialEq + ?Sized + 'a, + B: PartialEq + ?Sized + 'a, +{ + fn eq(&self, other: &Self) -> bool { + self.left() == other.left() && self.right() == other.right() + } +} + +impl<'a, A, B> Eq for dyn KeyPair + 'a +where + A: Eq + ?Sized + 'a, + B: Eq + ?Sized + 'a, +{ +} + +impl<'a, A, B> Hash for dyn KeyPair + 'a +where + A: ?Sized + 'a, + B: ?Sized + 'a, + for<'b, 'c> (&'b A, &'c B): Hash, +{ + fn hash(&self, state: &mut H) { + (self.left(), self.right()).hash(state) + } +} diff --git a/pgdog/src/frontend/router/parser/cache/mod.rs b/pgdog/src/frontend/router/parser/cache/mod.rs index 0a3437813..5eee7c1bc 100644 --- a/pgdog/src/frontend/router/parser/cache/mod.rs +++ b/pgdog/src/frontend/router/parser/cache/mod.rs @@ -6,6 +6,7 @@ pub mod ast; pub mod cache_impl; pub mod context; pub mod fingerprint; +pub mod key_pair; pub use ast::*; pub use cache_impl::*; diff --git a/pgdog/src/frontend/router/parser/mod.rs b/pgdog/src/frontend/router/parser/mod.rs index ccd86044a..f8973b5f9 100644 --- a/pgdog/src/frontend/router/parser/mod.rs +++ b/pgdog/src/frontend/router/parser/mod.rs @@ -55,6 +55,7 @@ pub use query::QueryParser; pub use rewrite::{ Assignment, AssignmentValue, ShardKeyRewritePlan, StatementRewrite, StatementRewriteContext, }; +pub(crate) use route::ShardOptionExt; pub use route::{Route, Shard, ShardWithPriority, ShardsWithPriority}; pub use schema::Schema; pub use sequence::{OwnedSequence, Sequence}; diff --git a/pgdog/src/frontend/router/parser/route.rs b/pgdog/src/frontend/router/parser/route.rs index 06ad11d35..b63a1866e 100644 --- a/pgdog/src/frontend/router/parser/route.rs +++ b/pgdog/src/frontend/router/parser/route.rs @@ -50,6 +50,16 @@ impl Shard { } } +pub(crate) trait ShardOptionExt { + fn is_direct(&self) -> bool; +} + +impl ShardOptionExt for Option { + fn is_direct(&self) -> bool { + self.as_ref().map(Shard::is_direct).unwrap_or(false) + } +} + impl From> for Shard { fn from(value: Option) -> Self { if let Some(value) = value {