Skip to content

Commit e1168fe

Browse files
committed
sync: update from internal GitLab repository
Content updated: Files: - docker-compose.yml Directories: - common/ - iris_rag/ - config/ - docs/ - scripts/ - objectscript/ - tests/ Synced at: 2025-08-03 16:11:14
1 parent b2b2965 commit e1168fe

34 files changed

+5430
-482
lines changed
Lines changed: 386 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,386 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Database Audit Middleware
4+
5+
This middleware intercepts and logs all database operations to provide
6+
a complete audit trail of real SQL commands vs mocked operations.
7+
8+
Integrates with IRIS connection managers to capture actual database activity.
9+
"""
10+
11+
import logging
12+
import time
13+
import inspect
14+
from typing import Any, List, Dict, Optional, Callable, Union
15+
from functools import wraps
16+
17+
from .sql_audit_logger import get_sql_audit_logger, log_sql_execution
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
class AuditableCursor:
    """
    Wrapper for a database cursor that logs all SQL operations.

    ``execute()``, ``fetchall()``, ``fetchone()`` and ``fetchmany()`` are
    intercepted and reported to the SQL audit logger. Any other attribute
    is delegated unchanged to the wrapped cursor.
    """

    def __init__(self, original_cursor, connection_type: str = "unknown"):
        """
        Args:
            original_cursor: The cursor object to wrap.
            connection_type: Label describing the owning connection.
        """
        self.original_cursor = original_cursor
        self.connection_type = connection_type
        self.audit_logger = get_sql_audit_logger()

        # Track current operation for correlation between execute() and the
        # fetch*() calls that follow it.
        self._current_operation_id = None
        self._current_sql = None
        self._current_params = None
        self._operation_start_time = None

    def execute(self, sql: str, parameters: Any = None) -> Any:
        """
        Execute SQL with audit logging.

        Args:
            sql: The SQL statement to run.
            parameters: Optional bind parameters; when ``None`` the wrapped
                cursor is called with the statement only.

        Returns:
            Whatever the wrapped cursor's ``execute()`` returns.

        Raises:
            Exception: Any error raised by the wrapped cursor (or by the
                audit logger) is logged and re-raised unchanged.
        """
        self._operation_start_time = time.time()
        self._current_sql = sql
        self._current_params = parameters or []

        # Log the SQL operation start; the returned id is reused in the log
        # lines emitted by the fetch*() methods and close().
        self._current_operation_id = self.audit_logger.log_sql_operation(
            sql_statement=sql,
            parameters=self._current_params,
        )

        logger.debug(f"🔴 REAL SQL EXECUTION [{self._current_operation_id}]: {sql[:100]}...")

        try:
            # Execute the actual SQL - handle both parameterized and
            # non-parameterized calls
            if parameters is None:
                result = self.original_cursor.execute(sql)
            else:
                result = self.original_cursor.execute(sql, parameters)

            # Log successful execution
            # NOTE(review): this is a second log_sql_operation() call for the
            # same statement; whether the logger merges it with the entry
            # created above is not visible from this module - confirm.
            execution_time = (time.time() - self._operation_start_time) * 1000
            self.audit_logger.log_sql_operation(
                sql_statement=sql,
                parameters=self._current_params,
                execution_time_ms=execution_time,
                rows_affected=getattr(self.original_cursor, 'rowcount', None)
            )

            return result

        except Exception as e:
            # Log failed execution
            execution_time = (time.time() - self._operation_start_time) * 1000
            self.audit_logger.log_sql_operation(
                sql_statement=sql,
                parameters=self._current_params,
                execution_time_ms=execution_time,
                error=str(e)
            )

            logger.error(f"❌ SQL EXECUTION FAILED [{self._current_operation_id}]: {e}")
            raise

    def _log_fetch_result(self, result_count: int) -> None:
        """Report a fetch result count for the statement last passed to execute()."""
        # Nothing to correlate with if execute() has not been called yet.
        if not self._current_operation_id:
            return

        execution_time = (time.time() - self._operation_start_time) * 1000
        self.audit_logger.log_sql_operation(
            sql_statement=self._current_sql,
            parameters=self._current_params,
            execution_time_ms=execution_time,
            result_count=result_count
        )

    def fetchall(self) -> List[Any]:
        """Fetch all results with audit logging."""
        try:
            results = self.original_cursor.fetchall()

            self._log_fetch_result(len(results) if results else 0)

            logger.debug(f"🔴 REAL SQL FETCHALL [{self._current_operation_id}]: {len(results) if results else 0} rows")
            return results

        except Exception as e:
            logger.error(f"❌ SQL FETCHALL FAILED [{self._current_operation_id}]: {e}")
            raise

    def fetchone(self) -> Any:
        """Fetch one result with audit logging."""
        try:
            result = self.original_cursor.fetchone()

            self._log_fetch_result(1 if result else 0)

            logger.debug(f"🔴 REAL SQL FETCHONE [{self._current_operation_id}]: {'1 row' if result else 'no rows'}")
            return result

        except Exception as e:
            logger.error(f"❌ SQL FETCHONE FAILED [{self._current_operation_id}]: {e}")
            raise

    def fetchmany(self, size: Optional[int] = None) -> List[Any]:
        """
        Fetch many results with audit logging.

        Args:
            size: Maximum number of rows to fetch. When ``None`` the wrapped
                cursor's ``fetchmany()`` is called without an argument so the
                driver applies its own default batch size.
        """
        try:
            # Forwarding None explicitly makes drivers that expect an integer
            # raise TypeError, so only pass size when the caller supplied one.
            if size is None:
                results = self.original_cursor.fetchmany()
            else:
                results = self.original_cursor.fetchmany(size)

            self._log_fetch_result(len(results) if results else 0)

            logger.debug(f"🔴 REAL SQL FETCHMANY [{self._current_operation_id}]: {len(results) if results else 0} rows")
            return results

        except Exception as e:
            logger.error(f"❌ SQL FETCHMANY FAILED [{self._current_operation_id}]: {e}")
            raise

    def close(self):
        """Close cursor with audit logging."""
        logger.debug(f"🔴 REAL SQL CURSOR CLOSE [{self._current_operation_id}]")
        return self.original_cursor.close()

    def __getattr__(self, name):
        """Delegate other methods to the original cursor."""
        # __getattr__ only runs when normal lookup fails. If the wrapped
        # cursor itself is missing (instance created without __init__, e.g.
        # via __new__, copy or unpickling), delegating would re-enter this
        # method without bound, so fail with a regular AttributeError instead.
        if name == "original_cursor":
            raise AttributeError(name)
        return getattr(self.original_cursor, name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always close; the return value is None so exceptions propagate.
        self.close()
166+
167+
168+
class AuditableConnection:
    """
    Wrapper for a database connection that provides auditable cursors.

    ``cursor()``, ``commit()``, ``rollback()`` and ``close()`` are logged;
    any other attribute is delegated unchanged to the wrapped connection.
    """

    def __init__(self, original_connection, connection_type: str = "IRIS"):
        """
        Args:
            original_connection: The connection object to wrap.
            connection_type: Label used in log messages and handed to every
                cursor created from this connection.
        """
        self.original_connection = original_connection
        self.connection_type = connection_type

        logger.info(f"🔴 REAL DATABASE CONNECTION CREATED: {connection_type}")

    def cursor(self) -> "AuditableCursor":
        """Create an auditable cursor."""
        original_cursor = self.original_connection.cursor()
        return AuditableCursor(original_cursor, self.connection_type)

    def commit(self):
        """Commit transaction with audit logging."""
        logger.info(f"🔴 REAL DATABASE COMMIT: {self.connection_type}")
        return self.original_connection.commit()

    def rollback(self):
        """Rollback transaction with audit logging."""
        logger.warning(f"🔴 REAL DATABASE ROLLBACK: {self.connection_type}")
        return self.original_connection.rollback()

    def close(self):
        """Close connection with audit logging."""
        logger.info(f"🔴 REAL DATABASE CONNECTION CLOSED: {self.connection_type}")
        return self.original_connection.close()

    def __getattr__(self, name):
        """Delegate other methods to the original connection."""
        # __getattr__ only runs when normal lookup fails. If the wrapped
        # connection itself is missing (instance created without __init__,
        # e.g. via __new__, copy or unpickling), delegating would re-enter
        # this method without bound, so raise a regular AttributeError.
        if name == "original_connection":
            raise AttributeError(name)
        return getattr(self.original_connection, name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Closes the connection on exit; no commit or rollback is issued
        # here, and the None return lets any exception propagate.
        self.close()
208+
209+
210+
def audit_database_connection(connection_factory: Optional[Callable] = None, connection_type: str = "IRIS"):
    """
    Decorator to wrap connection factory functions with audit logging.

    The value returned by the decorated factory is wrapped in an
    ``AuditableConnection`` before being handed to the caller.

    Args:
        connection_factory: Callable that returns a database connection.
            May be omitted so the decorator can be configured with keyword
            arguments first.
        connection_type: Label passed to ``AuditableConnection``.

    Returns:
        The wrapped factory, or - when ``connection_factory`` is omitted -
        a decorator that produces the wrapped factory.

    Usage:
        @audit_database_connection
        def get_iris_connection():
            return iris.connect(...)

        @audit_database_connection(connection_type="IRIS-DBAPI")
        def get_dbapi_connection():
            return iris.connect(...)
    """
    def decorate(factory: Callable) -> Callable:
        @wraps(factory)
        def wrapper(*args, **kwargs):
            # Get the original connection
            original_connection = factory(*args, **kwargs)

            # Wrap it with auditing
            return AuditableConnection(original_connection, connection_type)

        return wrapper

    # Called as @audit_database_connection(connection_type=...): no factory
    # yet, so hand back the decorator itself.
    if connection_factory is None:
        return decorate

    return decorate(connection_factory)
230+
231+
232+
def patch_iris_connection_manager():
    """
    Monkey patch the IRIS connection manager to add audit logging.

    Replaces ``common.iris_dbapi_connector.get_iris_dbapi_connection`` and
    ``common.iris_connection_manager.get_iris_connection`` with versions
    whose return value is wrapped by ``audit_database_connection``.

    This should be called at the start of tests to ensure all database
    operations are logged. Calling it more than once is safe: a factory
    that already carries the patch marker is left untouched, so connections
    are never wrapped twice.

    If either module (or function) cannot be imported, a warning is logged
    and the remaining patching is skipped; a patch applied before the
    failure stays in place.

    NOTE(review): only the attributes on the two modules are replaced.
    Code that bound the original functions earlier via
    ``from ... import ...`` presumably keeps the unpatched reference -
    confirm call sites.
    """
    try:
        # Patch the main connection function used by ConnectionManager
        from common.iris_dbapi_connector import get_iris_dbapi_connection as original_dbapi_connection

        # Skip when the module attribute is already one of our wrappers.
        if not getattr(original_dbapi_connection, "_sql_audit_patched", False):
            # Create auditable version for DBAPI
            @audit_database_connection
            def auditable_dbapi_connection(*args, **kwargs):
                return original_dbapi_connection(*args, **kwargs)

            # Marker consulted on later calls to detect an existing patch.
            auditable_dbapi_connection._sql_audit_patched = True

            # Monkey patch the DBAPI connector module (used by ConnectionManager)
            import common.iris_dbapi_connector
            common.iris_dbapi_connector.get_iris_dbapi_connection = auditable_dbapi_connection

        # Also patch the general connection manager for backward compatibility
        from common.iris_connection_manager import get_iris_connection as original_get_iris_connection

        if not getattr(original_get_iris_connection, "_sql_audit_patched", False):
            # Create auditable version
            @audit_database_connection
            def auditable_get_iris_connection(*args, **kwargs):
                return original_get_iris_connection(*args, **kwargs)

            auditable_get_iris_connection._sql_audit_patched = True

            # Monkey patch the module
            import common.iris_connection_manager
            common.iris_connection_manager.get_iris_connection = auditable_get_iris_connection

        logger.info("✅ IRIS connection manager patched for SQL audit logging")

    except ImportError as e:
        logger.warning(f"Could not patch IRIS connection manager: {e}")
268+
269+
270+
def mock_operation_tracker(original_method: Callable):
    """
    Decorator to track mocked database operations.

    This helps distinguish between real and mocked operations in tests.
    Each call is reported to the SQL audit logger as
    ``MOCKED_<METHOD NAME>`` before the wrapped method runs.

    Args:
        original_method: The (mocked) method to wrap. Its first positional
            argument is left out of the logged parameters.

    Returns:
        A wrapper with the same signature that logs, then calls
        ``original_method`` and returns its result.
    """
    @wraps(original_method)
    def wrapper(*args, **kwargs):
        # Get the mock call info
        method_name = original_method.__name__

        # A result count is only reported when the caller passed a
        # ``return_value`` keyword. Values without a length (None, ints,
        # ...) must not make the tracked call itself fail, so they are
        # reported as "unknown" instead.
        result_count = None
        if 'return_value' in kwargs:
            try:
                result_count = len(kwargs['return_value'])
            except TypeError:
                result_count = None

        # Log the mocked operation
        audit_logger = get_sql_audit_logger()
        operation_id = audit_logger.log_sql_operation(
            sql_statement=f"MOCKED_{method_name.upper()}",
            parameters=list(args[1:]) if len(args) > 1 else [],
            execution_time_ms=0.001,  # Mocks are fast
            result_count=result_count
        )

        logger.debug(f"🟡 MOCKED OPERATION [{operation_id}]: {method_name}")

        # Call the original mocked method
        return original_method(*args, **kwargs)

    return wrapper
296+
297+
298+
class DatabaseOperationCounter:
    """
    Helper that tallies and categorises the database operations recorded by
    the SQL audit logger during test execution.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Put every counter back to its initial, empty state."""
        self.real_operations = self.mocked_operations = 0
        self.operation_details = list()

    def count_operations(self, test_name: str = None) -> Dict[str, Any]:
        """
        Summarise the operations held by the audit logger.

        Args:
            test_name: When truthy, only operations the logger returns for
                that test are considered; otherwise all logged operations.

        Returns:
            Dictionary with operation counts and analysis
        """
        sql_log = get_sql_audit_logger()
        recorded = sql_log.get_operations_by_test(test_name) if test_name else sql_log.operations

        # Split the log by execution context; entries with any other
        # context only contribute to the total.
        real_entries = []
        mocked_entries = []
        for entry in recorded:
            context = entry.execution_context
            if context == 'real_database':
                real_entries.append(entry)
            elif context == 'mocked':
                mocked_entries.append(entry)

        def preview(statement):
            # Long statements are cut to 100 characters plus an ellipsis.
            if len(statement) > 100:
                return statement[:100] + "..."
            return statement

        real_detail = []
        for entry in real_entries:
            real_detail.append({
                "operation_id": entry.operation_id,
                "sql": preview(entry.sql_statement),
                "execution_time_ms": entry.execution_time_ms,
                "result_count": entry.result_count,
            })

        mocked_detail = []
        for entry in mocked_entries:
            mocked_detail.append({
                "operation_id": entry.operation_id,
                "sql": entry.sql_statement,
                "test_name": entry.test_name,
            })

        # Ratio of real to mocked operations; the divisor is clamped to 1 so
        # a run without mocks does not divide by zero.
        mocked_divisor = max(len(mocked_entries), 1)

        summary = {}
        summary["total_operations"] = len(recorded)
        summary["real_database_operations"] = len(real_entries)
        summary["mocked_operations"] = len(mocked_entries)
        summary["real_operations_detail"] = real_detail
        summary["mocked_operations_detail"] = mocked_detail
        summary["test_isolation_score"] = len(real_entries) / mocked_divisor
        return summary
352+
353+
354+
# Global instance for easy access
operation_counter = DatabaseOperationCounter()


if __name__ == "__main__":
    # Test the audit middleware: manual smoke test that logs one operation
    # per context and prints the resulting analysis.
    print("Testing Database Audit Middleware...")

    # Simulate database operations
    audit_logger = get_sql_audit_logger()

    # The context names used below ('real_database' / 'mocked') are the same
    # values DatabaseOperationCounter.count_operations() filters on.
    # NOTE(review): the meaning of set_context()'s second positional
    # argument ('BasicRAG') is not visible in this module - confirm against
    # sql_audit_logger.
    with audit_logger.set_context('real_database', 'BasicRAG'):
        audit_logger.log_sql_operation(
            "SELECT * FROM RAG.SourceDocuments WHERE doc_id = ?",
            ["test_doc_1"],
            execution_time_ms=15.3,
            result_count=1
        )

    with audit_logger.set_context('mocked', test_name='test_basic_functionality'):
        audit_logger.log_sql_operation(
            "MOCKED_EXECUTE",
            ["SELECT * FROM RAG.SourceDocuments"],
            execution_time_ms=0.001,
            result_count=3
        )

    # Generate analysis over everything logged so far (no test_name filter)
    counter = DatabaseOperationCounter()
    analysis = counter.count_operations()

    print(f"Analysis: {analysis}")
    print(f"Real vs Mock ratio: {analysis['test_isolation_score']:.2f}")

0 commit comments

Comments
 (0)