Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion .github/wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ schemas
psycopg
html
PostgreSQLLoader
PostgresLoader
api
postgres
postgresql
Expand Down Expand Up @@ -77,6 +78,7 @@ LLM
Ollama
OpenAI
OpenAI's
DockerHub
Dockerhub
FDE
github
Expand All @@ -98,4 +100,21 @@ Sanitization
JOINs
subqueries
subquery
TTL
TTL

config
docstring
dotenv
ESLint
GraphNotFoundError
HSTS
init
InternalError
InvalidArgumentError
Middleware
monorepo
PRs
pylint
pytest
Radix
Zod
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,7 @@ demo_tokens.py
/blob-report/
/playwright/.cache/
/playwright/.auth/
e2e/.auth/
e2e/.auth/
# Build artifacts
clients/python/queryweaver_client.egg-info/
clients/ts/dist/
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ with requests.post(url, headers=headers, json={"chat": ["Count orders last week"
continue
obj = json.loads(part)
print('STREAM:', obj)
```

Notes & tips
- Graph IDs are namespaced per-user. When calling the API directly use the plain graph id (the server will namespace by the authenticated user). For uploaded files the `database` field determines the saved graph id.
Expand Down
117 changes: 95 additions & 22 deletions api/loaders/postgres_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import decimal
import logging
from typing import AsyncGenerator, Dict, Any, List, Tuple
from urllib.parse import urlparse, parse_qs, unquote

import psycopg2
from psycopg2 import sql
Expand Down Expand Up @@ -52,7 +53,7 @@ class PostgresLoader(BaseLoader):

@staticmethod
def _execute_sample_query(
cursor, table_name: str, col_name: str, sample_size: int = 3
cursor: Any, table_name: str, col_name: str, sample_size: int = 3
) -> List[Any]:
"""
Execute query to get random sample values for a column.
Expand Down Expand Up @@ -96,39 +97,96 @@ def _serialize_value(value):
return None
return value

@staticmethod
def parse_schema_from_url(connection_url: str) -> str:
"""
Parse the search_path from the connection URL's options parameter.

The options parameter follows PostgreSQL's libpq format:
postgresql://user:pass@host:port/db?options=-csearch_path%3Dschema_name

Args:
connection_url: PostgreSQL connection URL

Returns:
The first schema from search_path, or 'public' if not specified
"""
try:
parsed = urlparse(connection_url)
query_params = parse_qs(parsed.query)

options = query_params.get('options', [])
if not options:
return 'public'

options_str = unquote(options[0])

# Parse -c search_path=value from options
# Format can be: -csearch_path=schema or -c search_path=schema
# Match comma-separated schema tokens (supports spaces after commas).
match = re.search(r'-c\s*search_path\s*=\s*([^\s,]+(?:\s*,\s*[^\s,]+)*)', options_str, re.IGNORECASE)
if match:
search_path = match.group(1)
schemas = search_path.split(',')
for s in schemas:
s = s.strip().strip('"\'')
if s and s != '$user':
return s
return 'public'

return 'public'

except Exception: # pylint: disable=broad-exception-caught
return 'public'

@staticmethod
async def load(prefix: str, connection_url: str) -> AsyncGenerator[tuple[bool, str], None]:
"""
Load the graph data from a PostgreSQL database into the graph database.

Args:
connection_url: PostgreSQL connection URL in format:
postgresql://username:password@host:port/database
postgresql://username:password@host:port/database
Optionally with schema via options parameter:
postgresql://...?options=-csearch_path%3Dschema_name

Returns:
Tuple[bool, str]: Success status and message
"""
conn = None
cursor = None
try:
# Parse schema from connection URL (defaults to 'public')
schema = PostgresLoader.parse_schema_from_url(connection_url)

# Connect to PostgreSQL database
conn = psycopg2.connect(connection_url)
cursor = conn.cursor()

# Set the session search_path to the parsed schema so unqualified
# table references (e.g. in sample queries) resolve correctly.
cursor.execute(
sql.SQL("SET search_path TO {}").format(sql.Identifier(schema))
)
Comment on lines +166 to +170
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

Does a PostgreSQL/libpq connection URI with `options=-csearch_path=tenant,extensions` automatically set the session `search_path`, and would executing `SET search_path TO tenant` afterwards discard the `extensions` entry?

💡 Result:

Yes.

  • In a libpq connection URI, options=-csearch_path=tenant,extensions is sent to the server “at connection start”, and -c name=value sets the session’s value of that run-time parameter—so it sets the session search_path automatically when the connection is established. [1]
  • Running SET search_path TO tenant afterwards replaces the search_path with just tenant (PostgreSQL’s docs explicitly note that SET search_path TO myschema leaves you without access to schemas that were previously in the path, like public). So it would drop the extensions entry (with pg_catalog still effectively searched). [2]

If you want to keep extensions available, use SET search_path TO tenant, extensions (or include whatever other schemas you need) rather than SET search_path TO tenant. [2]

Sources:
[1] PostgreSQL libpq options parameter docs (sets session GUCs via -c). (postgresql.org)
[2] PostgreSQL “Schemas” docs (SET search_path TO ... examples; SET search_path TO myschema removes access to other schemas; pg_catalog is always effectively in the path). (postgresql.org)

Citations:


Don't collapse the caller's search_path to a single schema.

When psycopg2.connect(connection_url) receives options=-csearch_path=tenant,extensions, the session search_path is automatically set to include both schemas. Executing SET search_path TO schema afterwards replaces the entire search_path with only that single schema, discarding any additional entries like extension schemas or $user references. This breaks the connection's schema resolution for sample queries and diverges from the user's original URL configuration. Either preserve the full search_path when setting it, or use schema-qualified table references instead of mutating the session.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@api/loaders/postgres_loader.py` around lines 166 - 170, The code currently
replaces the entire session search_path by calling cursor.execute(sql.SQL("SET
search_path TO {}").format(sql.Identifier(schema))), which discards any existing
entries set via the connection options; instead read the current session
search_path (e.g. with "SHOW search_path"), parse it, and if the target schema
(variable schema) is not already present, prepend or append it to the existing
list and then call SET search_path TO <full_list>; alternatively avoid mutating
the session at all and use schema-qualified table references when constructing
sample queries. Update the logic around cursor.execute / sql.SQL("SET
search_path TO {}") in postgres_loader.py to preserve and reuse the existing
search_path rather than collapsing it to a single schema.


# Extract database name from connection URL
db_name = connection_url.split('/')[-1]
if '?' in db_name:
db_name = db_name.split('?')[0]

# Get all table information
yield True, "Extracting table information..."
entities = PostgresLoader.extract_tables_info(cursor)
entities = PostgresLoader.extract_tables_info(cursor, schema)

yield True, "Extracting relationship information..."
# Get all relationship information
relationships = PostgresLoader.extract_relationships(cursor)
relationships = PostgresLoader.extract_relationships(cursor, schema)

# Close database connection
# Close database connection before graph loading
cursor.close()
cursor = None
conn.close()
conn = None

yield True, "Loading data into graph..."
# Load data into graph
Expand All @@ -144,46 +202,53 @@ async def load(prefix: str, connection_url: str) -> AsyncGenerator[tuple[bool, s
except Exception as e: # pylint: disable=broad-exception-caught
logging.error("Error loading PostgreSQL schema: %s", e)
yield False, "Failed to load PostgreSQL database schema"
finally:
if cursor is not None:
cursor.close()
if conn is not None:
conn.close()

@staticmethod
def extract_tables_info(cursor) -> Dict[str, Any]:
def extract_tables_info(cursor: Any, schema: str = 'public') -> Dict[str, Any]:
"""
Extract table and column information from PostgreSQL database.

Args:
cursor: Database cursor
schema: Database schema to extract tables from (default: 'public')

Returns:
Dict containing table information
"""
entities = {}

# Get all tables in public schema
# Get all tables in the specified schema
cursor.execute("""
SELECT table_name, table_comment
FROM information_schema.tables t
LEFT JOIN (
SELECT schemaname, tablename, description as table_comment
FROM pg_tables pt
JOIN pg_class pc ON pc.relname = pt.tablename
JOIN pg_namespace pn ON pn.oid = pc.relnamespace AND pn.nspname = pt.schemaname
JOIN pg_description pd ON pd.objoid = pc.oid AND pd.objsubid = 0
WHERE pt.schemaname = 'public'
WHERE pt.schemaname = %s
) tc ON tc.tablename = t.table_name
WHERE t.table_schema = 'public'
WHERE t.table_schema = %s
AND t.table_type = 'BASE TABLE'
ORDER BY t.table_name;
""")
""", (schema, schema))

tables = cursor.fetchall()

for table_name, table_comment in tqdm.tqdm(tables, desc="Extracting table information"):
table_name = table_name.strip()

# Get column information for this table
columns_info = PostgresLoader.extract_columns_info(cursor, table_name)
columns_info = PostgresLoader.extract_columns_info(cursor, table_name, schema)

# Get foreign keys for this table
foreign_keys = PostgresLoader.extract_foreign_keys(cursor, table_name)
foreign_keys = PostgresLoader.extract_foreign_keys(cursor, table_name, schema)

# Generate table description
table_description = table_comment if table_comment else f"Table: {table_name}"
Expand All @@ -201,13 +266,14 @@ def extract_tables_info(cursor) -> Dict[str, Any]:
return entities

@staticmethod
def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]:
def extract_columns_info(cursor: Any, table_name: str, schema: str = 'public') -> Dict[str, Any]:
"""
Extract column information for a specific table.

Args:
cursor: Database cursor
table_name: Name of the table
schema: Database schema (default: 'public')

Returns:
Dict containing column information
Expand All @@ -230,24 +296,29 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]:
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage ku
ON tc.constraint_name = ku.constraint_name
AND tc.constraint_schema = ku.constraint_schema
WHERE tc.table_name = %s
AND tc.table_schema = %s
AND tc.constraint_type = 'PRIMARY KEY'
) pk ON pk.column_name = c.column_name
LEFT JOIN (
SELECT ku.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage ku
ON tc.constraint_name = ku.constraint_name
AND tc.constraint_schema = ku.constraint_schema
WHERE tc.table_name = %s
AND tc.table_schema = %s
AND tc.constraint_type = 'FOREIGN KEY'
) fk ON fk.column_name = c.column_name
LEFT JOIN pg_class pc ON pc.relname = c.table_name
LEFT JOIN pg_namespace pn ON pn.nspname = c.table_schema
LEFT JOIN pg_class pc ON pc.relname = c.table_name AND pc.relnamespace = pn.oid
LEFT JOIN pg_attribute pa ON pa.attrelid = pc.oid AND pa.attname = c.column_name
LEFT JOIN pg_description pgd ON pgd.objoid = pc.oid AND pgd.objsubid = pa.attnum
WHERE c.table_name = %s
AND c.table_schema = 'public'
AND c.table_schema = %s
ORDER BY c.ordinal_position;
""", (table_name, table_name, table_name))
""", (table_name, schema, table_name, schema, table_name, schema))

columns = cursor.fetchall()
columns_info = {}
Expand Down Expand Up @@ -289,13 +360,14 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]:
return columns_info

@staticmethod
def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]:
def extract_foreign_keys(cursor: Any, table_name: str, schema: str = 'public') -> List[Dict[str, str]]:
"""
Extract foreign key information for a specific table.

Args:
cursor: Database cursor
table_name: Name of the table
schema: Database schema (default: 'public')

Returns:
List of foreign key dictionaries
Expand All @@ -315,8 +387,8 @@ def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]:
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_name = %s
AND tc.table_schema = 'public';
""", (table_name,))
AND tc.table_schema = %s;
""", (table_name, schema))

foreign_keys = []
for constraint_name, column_name, foreign_table, foreign_column in cursor.fetchall():
Expand All @@ -330,12 +402,13 @@ def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]:
return foreign_keys

@staticmethod
def extract_relationships(cursor) -> Dict[str, List[Dict[str, str]]]:
def extract_relationships(cursor: Any, schema: str = 'public') -> Dict[str, List[Dict[str, str]]]:
"""
Extract all relationship information from the database.

Args:
cursor: Database cursor
schema: Database schema (default: 'public')

Returns:
Dict containing relationship information
Expand All @@ -355,9 +428,9 @@ def extract_relationships(cursor) -> Dict[str, List[Dict[str, str]]]:
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = 'public'
AND tc.table_schema = %s
ORDER BY tc.table_name, tc.constraint_name;
""")
""", (schema,))

relationships = {}
for (table_name, constraint_name, column_name,
Expand Down
Loading