diff --git a/store/postgres/src/sql/parser_tests.yaml b/store/postgres/src/sql/parser_tests.yaml index 7a3ef9c005a..c4155a8deac 100644 --- a/store/postgres/src/sql/parser_tests.yaml +++ b/store/postgres/src/sql/parser_tests.yaml @@ -127,4 +127,7 @@ ok: SELECT * FROM (SELECT "id", "timestamp", "sum" FROM "sgd0815"."stats_hour" WHERE block$ <= 2147483647) AS sh - name: nested query with CTE sql: select *, (with pg_user as (select 1) select 1) as one from pg_user - err: Unknown table pg_user + err: Unknown table pg_user +- name: Quoted name in CTE + sql: WITH "PG_USER" AS (SELECT 1) SELECT * FROM pg_user; + err: Unknown table pg_user diff --git a/store/postgres/src/sql/validation.rs b/store/postgres/src/sql/validation.rs index 0b629e8c416..e33629e2d15 100644 --- a/store/postgres/src/sql/validation.rs +++ b/store/postgres/src/sql/validation.rs @@ -6,6 +6,7 @@ use sqlparser::ast::{ ValueWithSpan, VisitMut, VisitorMut, }; use sqlparser::parser::Parser; +use std::fmt::Display; use std::result::Result; use std::{collections::HashSet, ops::ControlFlow}; @@ -40,6 +41,39 @@ pub enum Error { InternalError(String), } +/// A wrapper around table names that correctly handles quoted vs unquoted +/// comparisons of names +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct TableName(String); + +impl TableName { + fn as_str(&self) -> &str { + &self.0 + } +} + +impl From<&Ident> for TableName { + fn from(ident: &Ident) -> Self { + let Ident { + value, + quote_style, + span: _, + } = ident; + // Use quoted names verbatim, and normalize unquoted names to + // lowercase + match quote_style { + Some(_) => Self(value.clone()), + None => Self(value.to_lowercase()), + } + } +} + +impl Display for TableName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + /// Helper to track CTEs introduced by the main query or subqueries. Every /// time we enter a query, we need to track a new set of CTEs which must be /// discarded once we are done with that query. Otherwise, we might allow @@ -47,7 +81,7 @@ pub enum Error { /// (select 1) select 1) as one from pg_user` #[derive(Default)] struct CteStack { - stack: Vec>, + stack: Vec>, } impl CteStack { @@ -59,9 +93,9 @@ impl CteStack { self.stack.pop(); } - fn contains(&self, name: &str) -> bool { + fn contains(&self, name: &TableName) -> bool { for entry in self.stack.iter().rev() { - if entry.contains(&name.to_lowercase()) { + if entry.contains(name) { return true; } } @@ -77,7 +111,7 @@ impl CteStack { return ControlFlow::Break(Error::InternalError("CTE stack is empty".into())); }; for cte in ctes { - entry.insert(cte.alias.name.value.to_lowercase()); + entry.insert(TableName::from(&cte.alias.name)); } ControlFlow::Continue(()) } @@ -254,20 +288,20 @@ impl VisitorMut for Validator<'_> { return ControlFlow::Break(Error::NoQualifiedTables(name.to_string())); } let table_name = match &name.0[0] { - ObjectNamePart::Identifier(ident) => &ident.value, + ObjectNamePart::Identifier(ident) => TableName::from(ident), ObjectNamePart::Function(_) => { return ControlFlow::Break(Error::NoQualifiedTables(name.to_string())); } }; // CTES override subgraph tables - if self.ctes.contains(&table_name.to_lowercase()) && args.is_none() { + if self.ctes.contains(&table_name) && args.is_none() { return ControlFlow::Continue(()); } - let table = match (self.layout.table(table_name), args) { + let table = match (self.layout.table(table_name.as_str()), args) { (None, None) => { - return ControlFlow::Break(Error::UnknownTable(table_name.clone())); + return ControlFlow::Break(Error::UnknownTable(table_name.to_string())); } (Some(_), Some(_)) => { // Table exists but has args, must be a function @@ -278,7 +312,7 @@ impl VisitorMut for Validator<'_> { // aggregation table in the form () or // must be a function - if !self.layout.has_aggregation(table_name) { + if !self.layout.has_aggregation(table_name.as_str()) { // Not an aggregation, must be a function return self.validate_function_name(&name); } @@ -287,23 +321,24 @@ impl VisitorMut for Validator<'_> { if settings.is_some() { // We do not support settings on aggregation tables return ControlFlow::Break(Error::InvalidAggregationSyntax( - table_name.clone(), + table_name.to_string(), )); } let Some(intv) = extract_string_arg(args) else { // Looks like an aggregation, but argument is not a single string return ControlFlow::Break(Error::InvalidAggregationSyntax( - table_name.clone(), + table_name.to_string(), )); }; let Some(intv) = intv.parse::().ok() else { return ControlFlow::Break(Error::UnknownAggregationInterval( - table_name.clone(), + table_name.to_string(), intv, )); }; - let Some(table) = self.layout.aggregation_table(table_name, intv) else { + let Some(table) = self.layout.aggregation_table(table_name.as_str(), intv) + else { return self.validate_function_name(&name); }; table @@ -312,7 +347,7 @@ impl VisitorMut for Validator<'_> { if !table.object.is_object_type() { // Interfaces and aggregations can not be queried // with the table name directly - return ControlFlow::Break(Error::UnknownTable(table_name.clone())); + return ControlFlow::Break(Error::UnknownTable(table_name.to_string())); } table }