Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion store/postgres/src/sql/parser_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,7 @@
ok: SELECT * FROM (SELECT "id", "timestamp", "sum" FROM "sgd0815"."stats_hour" WHERE block$ <= 2147483647) AS sh
- name: nested query with CTE
sql: select *, (with pg_user as (select 1) select 1) as one from pg_user
err: Unknown table pg_user
err: Unknown table pg_user
- name: Quoted name in CTE
sql: WITH "PG_USER" AS (SELECT 1) SELECT * FROM pg_user;
err: Unknown table pg_user
63 changes: 49 additions & 14 deletions store/postgres/src/sql/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use sqlparser::ast::{
ValueWithSpan, VisitMut, VisitorMut,
};
use sqlparser::parser::Parser;
use std::fmt::Display;
use std::result::Result;
use std::{collections::HashSet, ops::ControlFlow};

Expand Down Expand Up @@ -40,14 +41,47 @@ pub enum Error {
InternalError(String),
}

/// The name of a table as referenced in a SQL query.
///
/// Wrapping the name in its own type, rather than passing a bare `String`
/// around, is what allows quoted and unquoted spellings of a name to be
/// compared correctly. Equality and hashing are derived, so two values are
/// equal exactly when the strings they hold are identical.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct TableName(String);

impl TableName {
    /// Borrows the wrapped name as a string slice.
    fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl From<&Ident> for TableName {
fn from(ident: &Ident) -> Self {
let Ident {
value,
quote_style,
span: _,
} = ident;
// Use quoted names verbatim, and normalize unquoted names to
// lowercase
match quote_style {
Some(_) => Self(value.clone()),
None => Self(value.to_lowercase()),
}
}
}

impl Display for TableName {
    /// Writes the wrapped name exactly as stored, without adding quotes.
    /// Width and padding flags on the caller's format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.0.as_str())
    }
}

/// Helper to track CTEs introduced by the main query or subqueries. Every
/// time we enter a query, we need to track a new set of CTEs which must be
/// discarded once we are done with that query. Otherwise, we might allow
/// access to forbidden tables with a query like `select *, (with pg_user as
/// (select 1) select 1) as one from pg_user`
#[derive(Default)]
struct CteStack {
stack: Vec<HashSet<String>>,
stack: Vec<HashSet<TableName>>,
}

impl CteStack {
Expand All @@ -59,9 +93,9 @@ impl CteStack {
self.stack.pop();
}

fn contains(&self, name: &str) -> bool {
fn contains(&self, name: &TableName) -> bool {
for entry in self.stack.iter().rev() {
if entry.contains(&name.to_lowercase()) {
if entry.contains(name) {
return true;
}
}
Expand All @@ -77,7 +111,7 @@ impl CteStack {
return ControlFlow::Break(Error::InternalError("CTE stack is empty".into()));
};
for cte in ctes {
entry.insert(cte.alias.name.value.to_lowercase());
entry.insert(TableName::from(&cte.alias.name));
}
ControlFlow::Continue(())
}
Expand Down Expand Up @@ -254,20 +288,20 @@ impl VisitorMut for Validator<'_> {
return ControlFlow::Break(Error::NoQualifiedTables(name.to_string()));
}
let table_name = match &name.0[0] {
ObjectNamePart::Identifier(ident) => &ident.value,
ObjectNamePart::Identifier(ident) => TableName::from(ident),
ObjectNamePart::Function(_) => {
return ControlFlow::Break(Error::NoQualifiedTables(name.to_string()));
}
};

// CTEs override subgraph tables
if self.ctes.contains(&table_name.to_lowercase()) && args.is_none() {
if self.ctes.contains(&table_name) && args.is_none() {
return ControlFlow::Continue(());
}

let table = match (self.layout.table(table_name), args) {
let table = match (self.layout.table(table_name.as_str()), args) {
(None, None) => {
return ControlFlow::Break(Error::UnknownTable(table_name.clone()));
return ControlFlow::Break(Error::UnknownTable(table_name.to_string()));
}
(Some(_), Some(_)) => {
// Table exists but has args, must be a function
Expand All @@ -278,7 +312,7 @@ impl VisitorMut for Validator<'_> {
// aggregation table in the form <name>(<interval>) or
// must be a function

if !self.layout.has_aggregation(table_name) {
if !self.layout.has_aggregation(table_name.as_str()) {
// Not an aggregation, must be a function
return self.validate_function_name(&name);
}
Expand All @@ -287,23 +321,24 @@ impl VisitorMut for Validator<'_> {
if settings.is_some() {
// We do not support settings on aggregation tables
return ControlFlow::Break(Error::InvalidAggregationSyntax(
table_name.clone(),
table_name.to_string(),
));
}
let Some(intv) = extract_string_arg(args) else {
// Looks like an aggregation, but argument is not a single string
return ControlFlow::Break(Error::InvalidAggregationSyntax(
table_name.clone(),
table_name.to_string(),
));
};
let Some(intv) = intv.parse::<AggregationInterval>().ok() else {
return ControlFlow::Break(Error::UnknownAggregationInterval(
table_name.clone(),
table_name.to_string(),
intv,
));
};

let Some(table) = self.layout.aggregation_table(table_name, intv) else {
let Some(table) = self.layout.aggregation_table(table_name.as_str(), intv)
else {
return self.validate_function_name(&name);
};
table
Expand All @@ -312,7 +347,7 @@ impl VisitorMut for Validator<'_> {
if !table.object.is_object_type() {
// Interfaces and aggregations can not be queried
// with the table name directly
return ControlFlow::Break(Error::UnknownTable(table_name.clone()));
return ControlFlow::Break(Error::UnknownTable(table_name.to_string()));
}
table
}
Expand Down
Loading