1+ #!/usr/bin/env python3
2+ """
3+ Database Audit Middleware
4+
5+ This middleware intercepts and logs all database operations to provide
6+ a complete audit trail of real SQL commands vs mocked operations.
7+
8+ Integrates with IRIS connection managers to capture actual database activity.
9+ """
10+
11+ import logging
12+ import time
13+ import inspect
14+ from typing import Any , List , Dict , Optional , Callable , Union
15+ from functools import wraps
16+
17+ from .sql_audit_logger import get_sql_audit_logger , log_sql_execution
18+
19+ logger = logging .getLogger (__name__ )
20+
21+
22+ class AuditableCursor :
23+ """
24+ Wrapper for database cursor that logs all SQL operations.
25+
26+ This class intercepts execute(), fetchall(), fetchone(), etc. to provide
27+ complete audit trail of actual database operations.
28+ """
29+
30+ def __init__ (self , original_cursor , connection_type : str = "unknown" ):
31+ self .original_cursor = original_cursor
32+ self .connection_type = connection_type
33+ self .audit_logger = get_sql_audit_logger ()
34+
35+ # Track current operation for correlation
36+ self ._current_operation_id = None
37+ self ._current_sql = None
38+ self ._current_params = None
39+ self ._operation_start_time = None
40+
41+ def execute (self , sql : str , parameters : Any = None ) -> Any :
42+ """Execute SQL with audit logging."""
43+ self ._operation_start_time = time .time ()
44+ self ._current_sql = sql
45+ self ._current_params = parameters or []
46+
47+ # Log the SQL operation start
48+ self ._current_operation_id = self .audit_logger .log_sql_operation (
49+ sql_statement = sql ,
50+ parameters = self ._current_params ,
51+ )
52+
53+ logger .debug (f"🔴 REAL SQL EXECUTION [{ self ._current_operation_id } ]: { sql [:100 ]} ..." )
54+
55+ try :
56+ # Execute the actual SQL - handle both parameterized and non-parameterized calls
57+ if parameters is None :
58+ result = self .original_cursor .execute (sql )
59+ else :
60+ result = self .original_cursor .execute (sql , parameters )
61+
62+ # Log successful execution
63+ execution_time = (time .time () - self ._operation_start_time ) * 1000
64+ self .audit_logger .log_sql_operation (
65+ sql_statement = sql ,
66+ parameters = self ._current_params ,
67+ execution_time_ms = execution_time ,
68+ rows_affected = getattr (self .original_cursor , 'rowcount' , None )
69+ )
70+
71+ return result
72+
73+ except Exception as e :
74+ # Log failed execution
75+ execution_time = (time .time () - self ._operation_start_time ) * 1000
76+ self .audit_logger .log_sql_operation (
77+ sql_statement = sql ,
78+ parameters = self ._current_params ,
79+ execution_time_ms = execution_time ,
80+ error = str (e )
81+ )
82+
83+ logger .error (f"❌ SQL EXECUTION FAILED [{ self ._current_operation_id } ]: { e } " )
84+ raise
85+
86+ def fetchall (self ) -> List [Any ]:
87+ """Fetch all results with audit logging."""
88+ try :
89+ results = self .original_cursor .fetchall ()
90+
91+ # Update the operation log with result count
92+ if self ._current_operation_id :
93+ execution_time = (time .time () - self ._operation_start_time ) * 1000
94+ self .audit_logger .log_sql_operation (
95+ sql_statement = self ._current_sql ,
96+ parameters = self ._current_params ,
97+ execution_time_ms = execution_time ,
98+ result_count = len (results ) if results else 0
99+ )
100+
101+ logger .debug (f"🔴 REAL SQL FETCHALL [{ self ._current_operation_id } ]: { len (results ) if results else 0 } rows" )
102+ return results
103+
104+ except Exception as e :
105+ logger .error (f"❌ SQL FETCHALL FAILED [{ self ._current_operation_id } ]: { e } " )
106+ raise
107+
108+ def fetchone (self ) -> Any :
109+ """Fetch one result with audit logging."""
110+ try :
111+ result = self .original_cursor .fetchone ()
112+
113+ # Update the operation log
114+ if self ._current_operation_id :
115+ execution_time = (time .time () - self ._operation_start_time ) * 1000
116+ self .audit_logger .log_sql_operation (
117+ sql_statement = self ._current_sql ,
118+ parameters = self ._current_params ,
119+ execution_time_ms = execution_time ,
120+ result_count = 1 if result else 0
121+ )
122+
123+ logger .debug (f"🔴 REAL SQL FETCHONE [{ self ._current_operation_id } ]: { '1 row' if result else 'no rows' } " )
124+ return result
125+
126+ except Exception as e :
127+ logger .error (f"❌ SQL FETCHONE FAILED [{ self ._current_operation_id } ]: { e } " )
128+ raise
129+
130+ def fetchmany (self , size : int = None ) -> List [Any ]:
131+ """Fetch many results with audit logging."""
132+ try :
133+ results = self .original_cursor .fetchmany (size )
134+
135+ # Update the operation log
136+ if self ._current_operation_id :
137+ execution_time = (time .time () - self ._operation_start_time ) * 1000
138+ self .audit_logger .log_sql_operation (
139+ sql_statement = self ._current_sql ,
140+ parameters = self ._current_params ,
141+ execution_time_ms = execution_time ,
142+ result_count = len (results ) if results else 0
143+ )
144+
145+ logger .debug (f"🔴 REAL SQL FETCHMANY [{ self ._current_operation_id } ]: { len (results ) if results else 0 } rows" )
146+ return results
147+
148+ except Exception as e :
149+ logger .error (f"❌ SQL FETCHMANY FAILED [{ self ._current_operation_id } ]: { e } " )
150+ raise
151+
152+ def close (self ):
153+ """Close cursor with audit logging."""
154+ logger .debug (f"🔴 REAL SQL CURSOR CLOSE [{ self ._current_operation_id } ]" )
155+ return self .original_cursor .close ()
156+
157+ def __getattr__ (self , name ):
158+ """Delegate other methods to the original cursor."""
159+ return getattr (self .original_cursor , name )
160+
161+ def __enter__ (self ):
162+ return self
163+
164+ def __exit__ (self , exc_type , exc_val , exc_tb ):
165+ self .close ()
166+
167+
168+ class AuditableConnection :
169+ """
170+ Wrapper for database connection that provides auditable cursors.
171+ """
172+
173+ def __init__ (self , original_connection , connection_type : str = "IRIS" ):
174+ self .original_connection = original_connection
175+ self .connection_type = connection_type
176+
177+ logger .info (f"🔴 REAL DATABASE CONNECTION CREATED: { connection_type } " )
178+
179+ def cursor (self ) -> AuditableCursor :
180+ """Create an auditable cursor."""
181+ original_cursor = self .original_connection .cursor ()
182+ return AuditableCursor (original_cursor , self .connection_type )
183+
184+ def commit (self ):
185+ """Commit transaction with audit logging."""
186+ logger .info (f"🔴 REAL DATABASE COMMIT: { self .connection_type } " )
187+ return self .original_connection .commit ()
188+
189+ def rollback (self ):
190+ """Rollback transaction with audit logging."""
191+ logger .warning (f"🔴 REAL DATABASE ROLLBACK: { self .connection_type } " )
192+ return self .original_connection .rollback ()
193+
194+ def close (self ):
195+ """Close connection with audit logging."""
196+ logger .info (f"🔴 REAL DATABASE CONNECTION CLOSED: { self .connection_type } " )
197+ return self .original_connection .close ()
198+
199+ def __getattr__ (self , name ):
200+ """Delegate other methods to the original connection."""
201+ return getattr (self .original_connection , name )
202+
203+ def __enter__ (self ):
204+ return self
205+
206+ def __exit__ (self , exc_type , exc_val , exc_tb ):
207+ self .close ()
208+
209+
210+ def audit_database_connection (connection_factory : Callable , connection_type : str = "IRIS" ):
211+ """
212+ Decorator to wrap connection factory functions with audit logging.
213+
214+ Usage:
215+ @audit_database_connection
216+ def get_iris_connection():
217+ return iris.connect(...)
218+ """
219+ @wraps (connection_factory )
220+ def wrapper (* args , ** kwargs ):
221+ # Get the original connection
222+ original_connection = connection_factory (* args , ** kwargs )
223+
224+ # Wrap it with auditing
225+ auditable_connection = AuditableConnection (original_connection , connection_type )
226+
227+ return auditable_connection
228+
229+ return wrapper
230+
231+
232+ def patch_iris_connection_manager ():
233+ """
234+ Monkey patch the IRIS connection manager to add audit logging.
235+
236+ This should be called at the start of tests to ensure all database
237+ operations are logged.
238+ """
239+ try :
240+ # Patch the main connection function used by ConnectionManager
241+ from common .iris_dbapi_connector import get_iris_dbapi_connection as original_dbapi_connection
242+
243+ # Create auditable version for DBAPI
244+ @audit_database_connection
245+ def auditable_dbapi_connection (* args , ** kwargs ):
246+ return original_dbapi_connection (* args , ** kwargs )
247+
248+ # Monkey patch the DBAPI connector module (used by ConnectionManager)
249+ import common .iris_dbapi_connector
250+ common .iris_dbapi_connector .get_iris_dbapi_connection = auditable_dbapi_connection
251+
252+ # Also patch the general connection manager for backward compatibility
253+ from common .iris_connection_manager import get_iris_connection as original_get_iris_connection
254+
255+ # Create auditable version
256+ @audit_database_connection
257+ def auditable_get_iris_connection (* args , ** kwargs ):
258+ return original_get_iris_connection (* args , ** kwargs )
259+
260+ # Monkey patch the module
261+ import common .iris_connection_manager
262+ common .iris_connection_manager .get_iris_connection = auditable_get_iris_connection
263+
264+ logger .info ("✅ IRIS connection manager patched for SQL audit logging" )
265+
266+ except ImportError as e :
267+ logger .warning (f"Could not patch IRIS connection manager: { e } " )
268+
269+
270+ def mock_operation_tracker (original_method : Callable ):
271+ """
272+ Decorator to track mocked database operations.
273+
274+ This helps distinguish between real and mocked operations in tests.
275+ """
276+ @wraps (original_method )
277+ def wrapper (* args , ** kwargs ):
278+ # Get the mock call info
279+ method_name = original_method .__name__
280+
281+ # Log the mocked operation
282+ audit_logger = get_sql_audit_logger ()
283+ operation_id = audit_logger .log_sql_operation (
284+ sql_statement = f"MOCKED_{ method_name .upper ()} " ,
285+ parameters = list (args [1 :]) if len (args ) > 1 else [],
286+ execution_time_ms = 0.001 , # Mocks are fast
287+ result_count = len (kwargs .get ('return_value' , [])) if 'return_value' in kwargs else None
288+ )
289+
290+ logger .debug (f"🟡 MOCKED OPERATION [{ operation_id } ]: { method_name } " )
291+
292+ # Call the original mocked method
293+ return original_method (* args , ** kwargs )
294+
295+ return wrapper
296+
297+
298+ class DatabaseOperationCounter :
299+ """
300+ Utility to count and categorize database operations during test execution.
301+ """
302+
303+ def __init__ (self ):
304+ self .reset ()
305+
306+ def reset (self ):
307+ """Reset all counters."""
308+ self .real_operations = 0
309+ self .mocked_operations = 0
310+ self .operation_details = []
311+
312+ def count_operations (self , test_name : str = None ) -> Dict [str , Any ]:
313+ """
314+ Count operations from the audit logger for a specific test.
315+
316+ Returns:
317+ Dictionary with operation counts and analysis
318+ """
319+ audit_logger = get_sql_audit_logger ()
320+
321+ if test_name :
322+ operations = audit_logger .get_operations_by_test (test_name )
323+ else :
324+ operations = audit_logger .operations
325+
326+ real_ops = [op for op in operations if op .execution_context == 'real_database' ]
327+ mocked_ops = [op for op in operations if op .execution_context == 'mocked' ]
328+
329+ return {
330+ "total_operations" : len (operations ),
331+ "real_database_operations" : len (real_ops ),
332+ "mocked_operations" : len (mocked_ops ),
333+ "real_operations_detail" : [
334+ {
335+ "operation_id" : op .operation_id ,
336+ "sql" : op .sql_statement [:100 ] + "..." if len (op .sql_statement ) > 100 else op .sql_statement ,
337+ "execution_time_ms" : op .execution_time_ms ,
338+ "result_count" : op .result_count
339+ }
340+ for op in real_ops
341+ ],
342+ "mocked_operations_detail" : [
343+ {
344+ "operation_id" : op .operation_id ,
345+ "sql" : op .sql_statement ,
346+ "test_name" : op .test_name
347+ }
348+ for op in mocked_ops
349+ ],
350+ "test_isolation_score" : len (real_ops ) / max (len (mocked_ops ), 1 ) # Higher is better
351+ }
352+
353+
354+ # Global instance for easy access
355+ operation_counter = DatabaseOperationCounter ()
356+
357+
358+ if __name__ == "__main__" :
359+ # Test the audit middleware
360+ print ("Testing Database Audit Middleware..." )
361+
362+ # Simulate database operations
363+ audit_logger = get_sql_audit_logger ()
364+
365+ with audit_logger .set_context ('real_database' , 'BasicRAG' ):
366+ audit_logger .log_sql_operation (
367+ "SELECT * FROM RAG.SourceDocuments WHERE doc_id = ?" ,
368+ ["test_doc_1" ],
369+ execution_time_ms = 15.3 ,
370+ result_count = 1
371+ )
372+
373+ with audit_logger .set_context ('mocked' , test_name = 'test_basic_functionality' ):
374+ audit_logger .log_sql_operation (
375+ "MOCKED_EXECUTE" ,
376+ ["SELECT * FROM RAG.SourceDocuments" ],
377+ execution_time_ms = 0.001 ,
378+ result_count = 3
379+ )
380+
381+ # Generate analysis
382+ counter = DatabaseOperationCounter ()
383+ analysis = counter .count_operations ()
384+
385+ print (f"Analysis: { analysis } " )
386+ print (f"Real vs Mock ratio: { analysis ['test_isolation_score' ]:.2f} " )
0 commit comments