diff --git a/deepnote_toolkit/sql/sql_query_chaining.py b/deepnote_toolkit/sql/sql_query_chaining.py index ec4a20e..66b39a7 100644 --- a/deepnote_toolkit/sql/sql_query_chaining.py +++ b/deepnote_toolkit/sql/sql_query_chaining.py @@ -1,3 +1,5 @@ +from functools import lru_cache + import __main__ import sqlparse from sqlparse.tokens import Keyword @@ -67,34 +69,8 @@ def extract_table_reference_from_token(token): def extract_table_references(query): """Extract table references from SQL query including CTEs and subqueries.""" - table_references = set() - - try: - parsed = sqlparse.parse(query) - except Exception: - return [] - - # State to indicate the next token is a potential table name - expect_table = False - - for statement in parsed: - # Flattening the statement will let us process tokens in linear sequence meaning we won't have to process groups of tokens (Identifier or IdentifierList) - for token in statement.flatten(): - if token.is_whitespace or token.ttype == sqlparse.tokens.Punctuation: - continue - - if expect_table: - table_references.update(extract_table_reference_from_token(token)) - expect_table = False # reset state after table name is found - continue - - if token.ttype is Keyword: - normalized_token = token.normalized.upper() - # Check if token is "FROM" or contains "JOIN" - if normalized_token == "FROM" or "JOIN" in normalized_token: - expect_table = True - - return list(table_references) + # Uses tuple for immutability in cache, but returns a list for legacy compatibility + return list(_cached_extract_table_references(query)) def find_query_preview_references( @@ -140,13 +116,11 @@ def find_query_preview_references( # Check if the reference exists in the main module if hasattr(__main__, table_reference): variable_name = table_reference + if variable_name in query_preview_references: + # Already processed (no need for id/instance compare since variable name unique in dict) + continue variable = getattr(__main__, table_reference) - # If it's a QueryPreview object and not already in our list - # Use any() with a generator expression to check if the variable is already in the list - # This avoids using the pandas object in a boolean context - if isinstance(variable, DeepnoteQueryPreview) and not any( - id(variable) == id(ref) for ref in query_preview_references - ): + if isinstance(variable, DeepnoteQueryPreview): # Add it to our list query_preview_source = variable._deepnote_query query_preview_references[variable_name] = query_preview_source @@ -235,3 +209,39 @@ def unchain_sql_query(query): cte_sql = "WITH " + ",\n".join(cte_parts) final_query = f"{cte_sql}\n{query.strip()}" return final_query + + +# LRU cache for table reference extraction per-normalized query (covers _extracted flattening work) +@lru_cache(maxsize=64) +def _cached_extract_table_references(query: str): + table_references = set() + + try: + parsed = sqlparse.parse(query) + except Exception: + return tuple() + + # State to indicate the next token is a potential table name + + # State to indicate the next token is a potential table name + expect_table = False + + for statement in parsed: + # Flattening the statement will let us process tokens in linear sequence meaning we won't have to process groups of tokens (Identifier or IdentifierList) + for token in statement.flatten(): + if token.is_whitespace or token.ttype == sqlparse.tokens.Punctuation: + continue + + if expect_table: + table_references.update(extract_table_reference_from_token(token)) + expect_table = False # reset state after table name is found + continue + + ttype = token.ttype + if ttype is Keyword: + normalized_token = token.normalized.upper() + # Check if token is "FROM" or contains "JOIN" + if normalized_token == "FROM" or "JOIN" in normalized_token: + expect_table = True + + return tuple(table_references) diff --git a/deepnote_toolkit/sql/sql_utils.py b/deepnote_toolkit/sql/sql_utils.py index d5e24f8..686512f 100644 --- a/deepnote_toolkit/sql/sql_utils.py +++ b/deepnote_toolkit/sql/sql_utils.py @@ -1,8 +1,11 @@ +from functools import lru_cache + import sqlparse def is_single_select_query(sql_string): - parsed_queries = sqlparse.parse(sql_string) + parsed_queries = _cached_sqlparse_parse(sql_string) + # Check if there is only one query in the string # Check if there is only one query in the string if len(parsed_queries) != 1: @@ -10,3 +13,9 @@ def is_single_select_query(sql_string): # Check if the query is a SELECT statement return parsed_queries[0].get_type() == "SELECT" + + +# LRU cache for SQL parsing for up to 64 distinct queries +@lru_cache(maxsize=64) +def _cached_sqlparse_parse(sql_string: str): + return sqlparse.parse(sql_string)