-
Notifications
You must be signed in to change notification settings - Fork 89
feat: add support for postgres schema selection #475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4b50c39
f44647f
e21496a
7cb758c
9bc3dee
3075e95
5e31f3c
a12fa88
6e33932
0fc1b73
4d7914b
4987dec
04f5728
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| import decimal | ||
| import logging | ||
| from typing import AsyncGenerator, Dict, Any, List, Tuple | ||
| from urllib.parse import urlparse, parse_qs, unquote | ||
|
|
||
| import psycopg2 | ||
| from psycopg2 import sql | ||
|
|
@@ -52,7 +53,7 @@ class PostgresLoader(BaseLoader): | |
|
|
||
| @staticmethod | ||
| def _execute_sample_query( | ||
| cursor, table_name: str, col_name: str, sample_size: int = 3 | ||
| cursor: Any, table_name: str, col_name: str, sample_size: int = 3 | ||
| ) -> List[Any]: | ||
| """ | ||
| Execute query to get random sample values for a column. | ||
|
|
@@ -96,39 +97,96 @@ def _serialize_value(value): | |
| return None | ||
| return value | ||
|
|
||
@staticmethod
def parse_schema_from_url(connection_url: str) -> str:
    """
    Parse the search_path from the connection URL's options parameter.

    The options parameter follows PostgreSQL's libpq format:
    postgresql://user:pass@host:port/db?options=-csearch_path%3Dschema_name

    Only the first ``options`` query parameter is inspected. Both
    ``-csearch_path=...`` and ``-c search_path=...`` spellings are accepted,
    and the value may be a comma-separated list of schemas.

    Args:
        connection_url: PostgreSQL connection URL

    Returns:
        The first schema from search_path that is not empty and not the
        ``$user`` placeholder, or 'public' if no such schema is specified
        or the URL cannot be parsed.
    """
    default_schema = 'public'
    try:
        # parse_qs already percent-decodes the values it returns, so the
        # value must NOT be passed through unquote() a second time: doing so
        # would corrupt schema names that contain a literal '%'
        # (e.g. "my%2541" in the URL means the schema "my%41", not "myA").
        options = parse_qs(urlparse(connection_url).query).get('options', [])
        if not options:
            return default_schema

        # Parse -c search_path=value from options.
        # Format can be: -csearch_path=schema or -c search_path=schema
        # Match comma-separated schema tokens (supports spaces after commas).
        match = re.search(
            r'-c\s*search_path\s*=\s*([^\s,]+(?:\s*,\s*[^\s,]+)*)',
            options[0],
            re.IGNORECASE,
        )
        if match is None:
            return default_schema

        for token in match.group(1).split(','):
            # Drop surrounding whitespace and single/double quotes.
            candidate = token.strip().strip('"\'')
            # "$user" is a placeholder resolved by the server, not a
            # concrete schema name, so skip it.
            if candidate and candidate != '$user':
                return candidate
    except (ValueError, TypeError, AttributeError):
        # Best-effort parsing: a malformed URL (e.g. invalid IPv6 netloc)
        # or a non-string argument falls back to the default schema.
        return default_schema

    return default_schema
|
|
||
| @staticmethod | ||
| async def load(prefix: str, connection_url: str) -> AsyncGenerator[tuple[bool, str], None]: | ||
| """ | ||
| Load the graph data from a PostgreSQL database into the graph database. | ||
|
|
||
| Args: | ||
| connection_url: PostgreSQL connection URL in format: | ||
| postgresql://username:password@host:port/database | ||
| postgresql://username:password@host:port/database | ||
| Optionally with schema via options parameter: | ||
| postgresql://...?options=-csearch_path%3Dschema_name | ||
|
|
||
| Returns: | ||
| Tuple[bool, str]: Success status and message | ||
| """ | ||
| conn = None | ||
| cursor = None | ||
| try: | ||
| # Parse schema from connection URL (defaults to 'public') | ||
| schema = PostgresLoader.parse_schema_from_url(connection_url) | ||
|
|
||
| # Connect to PostgreSQL database | ||
| conn = psycopg2.connect(connection_url) | ||
| cursor = conn.cursor() | ||
gkorland marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Set the session search_path to the parsed schema so unqualified | ||
| # table references (e.g. in sample queries) resolve correctly. | ||
| cursor.execute( | ||
| sql.SQL("SET search_path TO {}").format(sql.Identifier(schema)) | ||
| ) | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+166
to
+170
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain 🌐 Web query:
💡 Result: Yes.
If you want to keep extensions available, use Sources: Citations:
Don't collapse the caller's When 🤖 Prompt for AI Agents |
||
|
|
||
| # Extract database name from connection URL | ||
| db_name = connection_url.split('/')[-1] | ||
| if '?' in db_name: | ||
| db_name = db_name.split('?')[0] | ||
|
|
||
| # Get all table information | ||
| yield True, "Extracting table information..." | ||
| entities = PostgresLoader.extract_tables_info(cursor) | ||
| entities = PostgresLoader.extract_tables_info(cursor, schema) | ||
|
|
||
| yield True, "Extracting relationship information..." | ||
| # Get all relationship information | ||
| relationships = PostgresLoader.extract_relationships(cursor) | ||
| relationships = PostgresLoader.extract_relationships(cursor, schema) | ||
|
|
||
| # Close database connection | ||
| # Close database connection before graph loading | ||
| cursor.close() | ||
| cursor = None | ||
| conn.close() | ||
| conn = None | ||
|
|
||
| yield True, "Loading data into graph..." | ||
| # Load data into graph | ||
|
|
@@ -144,46 +202,53 @@ async def load(prefix: str, connection_url: str) -> AsyncGenerator[tuple[bool, s | |
| except Exception as e: # pylint: disable=broad-exception-caught | ||
| logging.error("Error loading PostgreSQL schema: %s", e) | ||
| yield False, "Failed to load PostgreSQL database schema" | ||
| finally: | ||
| if cursor is not None: | ||
| cursor.close() | ||
| if conn is not None: | ||
| conn.close() | ||
|
|
||
| @staticmethod | ||
| def extract_tables_info(cursor) -> Dict[str, Any]: | ||
| def extract_tables_info(cursor: Any, schema: str = 'public') -> Dict[str, Any]: | ||
| """ | ||
| Extract table and column information from PostgreSQL database. | ||
|
|
||
| Args: | ||
| cursor: Database cursor | ||
| schema: Database schema to extract tables from (default: 'public') | ||
|
|
||
| Returns: | ||
| Dict containing table information | ||
| """ | ||
| entities = {} | ||
|
|
||
| # Get all tables in public schema | ||
| # Get all tables in the specified schema | ||
| cursor.execute(""" | ||
| SELECT table_name, table_comment | ||
| FROM information_schema.tables t | ||
| LEFT JOIN ( | ||
| SELECT schemaname, tablename, description as table_comment | ||
| FROM pg_tables pt | ||
| JOIN pg_class pc ON pc.relname = pt.tablename | ||
| JOIN pg_namespace pn ON pn.oid = pc.relnamespace AND pn.nspname = pt.schemaname | ||
| JOIN pg_description pd ON pd.objoid = pc.oid AND pd.objsubid = 0 | ||
| WHERE pt.schemaname = 'public' | ||
| WHERE pt.schemaname = %s | ||
| ) tc ON tc.tablename = t.table_name | ||
| WHERE t.table_schema = 'public' | ||
| WHERE t.table_schema = %s | ||
| AND t.table_type = 'BASE TABLE' | ||
| ORDER BY t.table_name; | ||
| """) | ||
| """, (schema, schema)) | ||
|
|
||
| tables = cursor.fetchall() | ||
|
|
||
| for table_name, table_comment in tqdm.tqdm(tables, desc="Extracting table information"): | ||
| table_name = table_name.strip() | ||
|
|
||
| # Get column information for this table | ||
| columns_info = PostgresLoader.extract_columns_info(cursor, table_name) | ||
| columns_info = PostgresLoader.extract_columns_info(cursor, table_name, schema) | ||
|
|
||
| # Get foreign keys for this table | ||
| foreign_keys = PostgresLoader.extract_foreign_keys(cursor, table_name) | ||
| foreign_keys = PostgresLoader.extract_foreign_keys(cursor, table_name, schema) | ||
|
|
||
| # Generate table description | ||
| table_description = table_comment if table_comment else f"Table: {table_name}" | ||
|
|
@@ -201,13 +266,14 @@ def extract_tables_info(cursor) -> Dict[str, Any]: | |
| return entities | ||
|
|
||
| @staticmethod | ||
| def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]: | ||
| def extract_columns_info(cursor: Any, table_name: str, schema: str = 'public') -> Dict[str, Any]: | ||
| """ | ||
| Extract column information for a specific table. | ||
|
|
||
| Args: | ||
| cursor: Database cursor | ||
| table_name: Name of the table | ||
| schema: Database schema (default: 'public') | ||
|
|
||
| Returns: | ||
| Dict containing column information | ||
|
|
@@ -230,24 +296,29 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]: | |
| FROM information_schema.table_constraints tc | ||
| JOIN information_schema.key_column_usage ku | ||
| ON tc.constraint_name = ku.constraint_name | ||
| AND tc.constraint_schema = ku.constraint_schema | ||
| WHERE tc.table_name = %s | ||
| AND tc.table_schema = %s | ||
| AND tc.constraint_type = 'PRIMARY KEY' | ||
| ) pk ON pk.column_name = c.column_name | ||
| LEFT JOIN ( | ||
| SELECT ku.column_name | ||
| FROM information_schema.table_constraints tc | ||
| JOIN information_schema.key_column_usage ku | ||
| ON tc.constraint_name = ku.constraint_name | ||
| AND tc.constraint_schema = ku.constraint_schema | ||
| WHERE tc.table_name = %s | ||
| AND tc.table_schema = %s | ||
gkorland marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| AND tc.constraint_type = 'FOREIGN KEY' | ||
| ) fk ON fk.column_name = c.column_name | ||
| LEFT JOIN pg_class pc ON pc.relname = c.table_name | ||
| LEFT JOIN pg_namespace pn ON pn.nspname = c.table_schema | ||
| LEFT JOIN pg_class pc ON pc.relname = c.table_name AND pc.relnamespace = pn.oid | ||
| LEFT JOIN pg_attribute pa ON pa.attrelid = pc.oid AND pa.attname = c.column_name | ||
| LEFT JOIN pg_description pgd ON pgd.objoid = pc.oid AND pgd.objsubid = pa.attnum | ||
| WHERE c.table_name = %s | ||
| AND c.table_schema = 'public' | ||
| AND c.table_schema = %s | ||
| ORDER BY c.ordinal_position; | ||
| """, (table_name, table_name, table_name)) | ||
| """, (table_name, schema, table_name, schema, table_name, schema)) | ||
|
|
||
| columns = cursor.fetchall() | ||
| columns_info = {} | ||
|
|
@@ -289,13 +360,14 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]: | |
| return columns_info | ||
|
|
||
| @staticmethod | ||
| def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]: | ||
| def extract_foreign_keys(cursor: Any, table_name: str, schema: str = 'public') -> List[Dict[str, str]]: | ||
| """ | ||
| Extract foreign key information for a specific table. | ||
|
|
||
| Args: | ||
| cursor: Database cursor | ||
| table_name: Name of the table | ||
| schema: Database schema (default: 'public') | ||
|
|
||
| Returns: | ||
| List of foreign key dictionaries | ||
|
|
@@ -315,8 +387,8 @@ def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]: | |
| AND ccu.table_schema = tc.table_schema | ||
| WHERE tc.constraint_type = 'FOREIGN KEY' | ||
| AND tc.table_name = %s | ||
| AND tc.table_schema = 'public'; | ||
| """, (table_name,)) | ||
| AND tc.table_schema = %s; | ||
| """, (table_name, schema)) | ||
|
|
||
| foreign_keys = [] | ||
| for constraint_name, column_name, foreign_table, foreign_column in cursor.fetchall(): | ||
|
|
@@ -330,12 +402,13 @@ def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]: | |
| return foreign_keys | ||
|
|
||
| @staticmethod | ||
| def extract_relationships(cursor) -> Dict[str, List[Dict[str, str]]]: | ||
| def extract_relationships(cursor: Any, schema: str = 'public') -> Dict[str, List[Dict[str, str]]]: | ||
| """ | ||
| Extract all relationship information from the database. | ||
|
|
||
| Args: | ||
| cursor: Database cursor | ||
| schema: Database schema (default: 'public') | ||
|
|
||
| Returns: | ||
| Dict containing relationship information | ||
|
|
@@ -355,9 +428,9 @@ def extract_relationships(cursor) -> Dict[str, List[Dict[str, str]]]: | |
| ON ccu.constraint_name = tc.constraint_name | ||
| AND ccu.table_schema = tc.table_schema | ||
| WHERE tc.constraint_type = 'FOREIGN KEY' | ||
| AND tc.table_schema = 'public' | ||
| AND tc.table_schema = %s | ||
| ORDER BY tc.table_name, tc.constraint_name; | ||
| """) | ||
| """, (schema,)) | ||
|
|
||
| relationships = {} | ||
| for (table_name, constraint_name, column_name, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.