78 changes: 44 additions & 34 deletions deepnote_toolkit/sql/sql_query_chaining.py
@@ -1,3 +1,5 @@
from functools import lru_cache

import __main__
import sqlparse
from sqlparse.tokens import Keyword
@@ -67,34 +69,8 @@ def extract_table_reference_from_token(token):

def extract_table_references(query):
"""Extract table references from SQL query including CTEs and subqueries."""
table_references = set()

try:
parsed = sqlparse.parse(query)
except Exception:
return []

# State to indicate the next token is a potential table name
expect_table = False

for statement in parsed:
# Flattening the statement will let us process tokens in linear sequence meaning we won't have to process groups of tokens (Identifier or IdentifierList)
for token in statement.flatten():
if token.is_whitespace or token.ttype == sqlparse.tokens.Punctuation:
continue

if expect_table:
table_references.update(extract_table_reference_from_token(token))
expect_table = False # reset state after table name is found
continue

if token.ttype is Keyword:
normalized_token = token.normalized.upper()
# Check if token is "FROM" or contains "JOIN"
if normalized_token == "FROM" or "JOIN" in normalized_token:
expect_table = True

return list(table_references)
# The cached helper returns an immutable tuple; convert back to a list for legacy compatibility
return list(_cached_extract_table_references(query))


def find_query_preview_references(
@@ -140,13 +116,11 @@ def find_query_preview_references(
# Check if the reference exists in the main module
if hasattr(__main__, table_reference):
variable_name = table_reference
if variable_name in query_preview_references:
# Already processed: no identity comparison needed, since variable names are unique dict keys
continue
variable = getattr(__main__, table_reference)
# If it's a QueryPreview object and not already in our list
# Use any() with a generator expression to check if the variable is already in the list
# This avoids using the pandas object in a boolean context
if isinstance(variable, DeepnoteQueryPreview) and not any(
id(variable) == id(ref) for ref in query_preview_references
):
if isinstance(variable, DeepnoteQueryPreview):
# Add it to our list
query_preview_source = variable._deepnote_query
query_preview_references[variable_name] = query_preview_source
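
The dict keyed by variable name replaces the old identity scan shown above. A minimal sketch, with hypothetical stand-in objects, of why the any()/id() dance existed and why name-based membership sidesteps it entirely (pandas-backed objects raise on truth-value tests):

import pandas as pd

# Hypothetical stand-ins for DeepnoteQueryPreview (pandas-backed objects)
df_a = pd.DataFrame({"x": [1, 2]})
df_b = pd.DataFrame({"x": [1, 2]})

refs = [df_a]
try:
    df_b in refs  # falls back to ==, which returns a DataFrame; bool() raises
except ValueError:
    pass  # "The truth value of a DataFrame is ambiguous..."

any(id(df_b) == id(ref) for ref in refs)  # old approach: safe, but O(n) per lookup

previews = {"df_a": "SELECT 1"}
"df_b" in previews  # new approach: O(1), no boolean coercion of pandas objects
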
@@ -235,3 +209,39 @@ def unchain_sql_query(query):
cte_sql = "WITH " + ",\n".join(cte_parts)
final_query = f"{cte_sql}\n{query.strip()}"
return final_query


# LRU cache for table-reference extraction, keyed by the query string (covers both the sqlparse work and the token flattening)
@lru_cache(maxsize=64)
def _cached_extract_table_references(query: str):
table_references = set()

try:
parsed = sqlparse.parse(query)
except Exception:
return tuple()

# State to indicate the next token is a potential table name
expect_table = False

for statement in parsed:
# Flattening the statement lets us process tokens as a linear sequence, so we never have to recurse into token groups (Identifier or IdentifierList)
for token in statement.flatten():
if token.is_whitespace or token.ttype == sqlparse.tokens.Punctuation:
continue

if expect_table:
table_references.update(extract_table_reference_from_token(token))
expect_table = False # reset state after table name is found
continue

if token.ttype is Keyword:
normalized_token = token.normalized.upper()
# Check if token is "FROM" or contains "JOIN"
if normalized_token == "FROM" or "JOIN" in normalized_token:
expect_table = True

return tuple(table_references)
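
To make the tuple-versus-list split concrete: lru_cache hands back the same object on every hit, so the cached value must be immutable, or one caller's mutation would corrupt every later lookup. A self-contained sketch of the pattern (the extraction body is a simple placeholder, not the real sqlparse logic):

from functools import lru_cache


@lru_cache(maxsize=64)
def _cached_refs(query: str) -> tuple:
    # Placeholder extraction; what matters is the immutable return type
    return tuple(sorted(w for w in query.split() if w.isidentifier()))


def extract_refs(query: str) -> list:
    # Public wrapper keeps the legacy list return type while giving each
    # caller a fresh list it is free to mutate
    return list(_cached_refs(query))


q = "SELECT * FROM users JOIN orders"
a = extract_refs(q)
b = extract_refs(q)                # cache hit: no re-parsing
a.append("scratch")                # mutating one caller's list ...
assert "scratch" not in b          # ... cannot touch the cached tuple
print(_cached_refs.cache_info())   # hits=1, misses=1, maxsize=64, currsize=1
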
11 changes: 10 additions & 1 deletion deepnote_toolkit/sql/sql_utils.py
@@ -1,12 +1,21 @@
from functools import lru_cache

import sqlparse


def is_single_select_query(sql_string):
parsed_queries = sqlparse.parse(sql_string)
parsed_queries = _cached_sqlparse_parse(sql_string)
# Check if there is only one query in the string
if len(parsed_queries) != 1:
return False

# Check if the query is a SELECT statement
return parsed_queries[0].get_type() == "SELECT"


# LRU cache of parsed SQL for up to 64 distinct query strings
@lru_cache(maxsize=64)
def _cached_sqlparse_parse(sql_string: str):
return sqlparse.parse(sql_string)
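
A sketch of the cached-parse pattern in isolation, assuming only that sqlparse is installed. One caveat worth noting: lru_cache returns the same Statement objects on every hit, and sqlparse token trees are mutable, so this is safe only while callers treat the parse result as read-only:

from functools import lru_cache

import sqlparse


@lru_cache(maxsize=64)
def _cached_parse(sql_string: str):
    return sqlparse.parse(sql_string)


def is_single_select(sql_string: str) -> bool:
    parsed = _cached_parse(sql_string)
    return len(parsed) == 1 and parsed[0].get_type() == "SELECT"


print(is_single_select("SELECT 1"))            # True  (miss: parses)
print(is_single_select("SELECT 1"))            # True  (hit: no re-parse)
print(is_single_select("SELECT 1; SELECT 2"))  # False (two statements)
print(_cached_parse.cache_info())              # hits=1, misses=2, currsize=2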