Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions db/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Unified database connection utilities.
Provides consistent connection management across all database operations.
"""
import os
import sqlite3
from typing import Optional
from contextlib import contextmanager
from utils.logger import get_logger

logger = get_logger(__name__)


def get_db_connection(
    db_path: str,
    timeout: float = 30.0,
    enable_wal: bool = True,
    enable_vector: bool = False,
    row_factory: bool = True
) -> sqlite3.Connection:
    """
    Create a database connection with consistent configuration.

    Args:
        db_path: Path to the SQLite database file
        timeout: Timeout in seconds for waiting on locks (default: 30.0)
        enable_wal: Enable Write-Ahead Logging mode (default: True)
        enable_vector: Load sqlite-vector extension (default: False)
        row_factory: Use sqlite3.Row factory for dict-like access (default: True)

    Returns:
        sqlite3.Connection object configured for the specified operations

    Raises:
        RuntimeError: If vector extension fails to load when enable_vector=True
    """
    # Ensure the parent directory exists so sqlite3.connect() can create the file.
    dirname = os.path.dirname(os.path.abspath(db_path))
    if dirname and not os.path.isdir(dirname):
        os.makedirs(dirname, exist_ok=True)

    # check_same_thread=False: connections may be handed across threads;
    # callers are responsible for serializing access.
    conn = sqlite3.connect(db_path, timeout=timeout, check_same_thread=False)

    try:
        if row_factory:
            conn.row_factory = sqlite3.Row

        # Enable WAL mode for better concurrency; non-fatal if unsupported
        # (e.g. on certain network filesystems).
        if enable_wal:
            try:
                conn.execute("PRAGMA journal_mode = WAL;")
            except Exception as e:
                logger.warning(f"Failed to enable WAL mode: {e}")

        # Mirror the Python-level timeout at the SQLite level (milliseconds).
        try:
            conn.execute(f"PRAGMA busy_timeout = {int(timeout * 1000)};")
        except Exception as e:
            logger.warning(f"Failed to set busy_timeout: {e}")

        # Load vector extension if requested; local import avoids a module-level
        # import cycle with vector_operations.
        if enable_vector:
            from .vector_operations import load_sqlite_vector_extension
            load_sqlite_vector_extension(conn)
            logger.debug(f"Vector extension loaded for connection to {db_path}")
    except Exception:
        # Fix: don't leak the open connection if configuration fails after
        # connect() succeeded (e.g. vector extension load raising RuntimeError).
        conn.close()
        raise

    return conn


@contextmanager
def db_connection(db_path: str, **kwargs):
    """
    Yield a configured database connection, closing it automatically on exit.

    Args:
        db_path: Path to the SQLite database file
        **kwargs: Additional arguments passed to get_db_connection()

    Yields:
        sqlite3.Connection object

    Example:
        with db_connection(db_path) as conn:
            cur = conn.cursor()
            cur.execute("SELECT * FROM files")
            results = cur.fetchall()
    """
    connection = get_db_connection(db_path, **kwargs)
    try:
        yield connection
    finally:
        # Best-effort close: a failure here should not mask an exception
        # raised inside the managed block.
        try:
            connection.close()
        except Exception as e:
            logger.warning(f"Error closing database connection: {e}")
48 changes: 18 additions & 30 deletions db/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from utils.config import CFG # config (keeps chunk_size etc if needed)
from utils.logger import get_logger
from utils.cache import project_cache, stats_cache, file_cache
from utils.retry import retry_on_db_locked
from .db_writer import get_writer

_LOG = get_logger(__name__)
Expand All @@ -15,24 +16,18 @@
_PREPARED_STATEMENTS = {}
_PREPARED_LOCK = threading.Lock()

import threading

# Simple connection helper: we open new connections per operation so the code is robust
# across threads. We set WAL journal mode for safer concurrency.
# Added a small timeout to avoid long blocking if DB is locked.
def _get_connection(db_path: str) -> sqlite3.Connection:
dirname = os.path.dirname(os.path.abspath(db_path))
if dirname and not os.path.isdir(dirname):
os.makedirs(dirname, exist_ok=True)
# timeout in seconds for busy sqlite; small value to avoid long blocking in web requests
conn = sqlite3.connect(db_path, check_same_thread=False, timeout=5.0)
conn.row_factory = sqlite3.Row
try:
conn.execute("PRAGMA journal_mode = WAL;")
except Exception:
# Not fatal — continue
pass
return conn
"""
DEPRECATED: Use db.connection.get_db_connection() instead.
This function is maintained for backward compatibility.
"""
from .connection import get_db_connection
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot I don't care about backward compatibility, remove it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the deprecated _get_connection wrapper function completely. Changes in commit 90fbbaf:

  • Removed wrapper function (lines 23-32)
  • Added direct import of get_db_connection from db.connection
  • Replaced all 18 usages with get_db_connection(path, timeout=5.0, enable_wal=True)

All database operations now use the unified connection function directly with no backward compatibility layer.

# Use shorter timeout for web requests (5s instead of default 30s)
return get_db_connection(db_path, timeout=5.0, enable_wal=True)


def _get_prepared_statement(conn: sqlite3.Connection, query_key: str, sql: str):
Expand Down Expand Up @@ -231,6 +226,8 @@ def clear_project_data(database_path: str) -> None:
cur.execute("DELETE FROM chunks")
# Delete files
cur.execute("DELETE FROM files")
# Clear vector metadata to allow re-indexing with different embedding dimensions
cur.execute("DELETE FROM vector_meta WHERE key = 'dimension'")
conn.commit()

# Invalidate caches
Expand Down Expand Up @@ -390,24 +387,15 @@ def _ensure_projects_dir():


def _retry_on_db_locked(func, *args, max_retries=DB_RETRY_COUNT, **kwargs):
"""Retry a database operation if it's locked."""
import time
last_error = None

for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except sqlite3.OperationalError as e:
if "database is locked" in str(e).lower() and attempt < max_retries - 1:
last_error = e
time.sleep(DB_RETRY_DELAY * (2 ** attempt)) # Exponential backoff
continue
raise
except Exception as e:
raise
"""
Retry a database operation if it's locked.

if last_error:
raise last_error
DEPRECATED: Use @retry_on_db_locked decorator from utils.retry instead.
This function is maintained for backward compatibility.
"""
# Use the retry decorator from utils.retry
decorated_func = retry_on_db_locked(max_retries=max_retries, base_delay=DB_RETRY_DELAY)(func)
return decorated_func(*args, **kwargs)


def _get_project_id(project_path: str) -> str:
Expand Down
84 changes: 51 additions & 33 deletions db/vector_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
"""
import os
import json
import time
import sqlite3
import importlib.resources
from typing import List, Dict, Any, Optional
from utils.logger import get_logger
from utils.retry import retry_on_exception

logger = get_logger(__name__)

Expand All @@ -18,7 +18,7 @@
SQLITE_VECTOR_RESOURCE = "vector"
SQLITE_VECTOR_VERSION_FN = "vector_version" # SELECT vector_version();

# Retry policy for DB-locked operations
# Retry policy for DB-locked operations (used by insert_chunk_vector_with_retry)
DB_LOCK_RETRY_COUNT = 6
DB_LOCK_RETRY_BASE_DELAY = 0.05 # seconds, exponential backoff multiplier

Expand All @@ -27,21 +27,18 @@ def connect_db(db_path: str, timeout: float = 30.0) -> sqlite3.Connection:
"""
Create a database connection with appropriate timeout and settings.

DEPRECATED: Use db.connection.get_db_connection() instead.
This function is maintained for backward compatibility.

Args:
db_path: Path to the SQLite database file
timeout: Timeout in seconds for waiting on locks

Returns:
sqlite3.Connection object configured for vector operations
"""
# timeout instructs sqlite to wait up to `timeout` seconds for locks
conn = sqlite3.connect(db_path, timeout=timeout, check_same_thread=False)
conn.row_factory = sqlite3.Row
try:
conn.execute("PRAGMA busy_timeout = 30000;") # 30s
except Exception:
pass
return conn
from .connection import get_db_connection
return get_db_connection(db_path, timeout=timeout, enable_vector=False)


def load_sqlite_vector_extension(conn: sqlite3.Connection) -> None:
Expand Down Expand Up @@ -163,30 +160,38 @@ def insert_chunk_vector_with_retry(conn: sqlite3.Connection, file_id: int, path:

q_vec = json.dumps(vector)

attempt = 0
while True:
# Use retry decorator for the actual insert operation
@retry_on_exception(
exceptions=(sqlite3.OperationalError,),
max_retries=DB_LOCK_RETRY_COUNT,
base_delay=DB_LOCK_RETRY_BASE_DELAY,
exponential_backoff=True
)
def _insert_with_retry():
"""Inner function with retry logic."""
# Check if it's a database locked error
try:
# use vector_as_f32(json) as per API so extension formats blob
cur.execute("INSERT INTO chunks (file_id, path, chunk_index, embedding) VALUES (?, ?, ?, vector_as_f32(?))",
(file_id, path, chunk_index, q_vec))
(file_id, path, chunk_index, q_vec))
conn.commit()
rowid = int(cur.lastrowid)
logger.debug(f"Inserted chunk vector for {path} chunk {chunk_index}, rowid={rowid}")
return rowid
except sqlite3.OperationalError as e:
msg = str(e).lower()
if "database is locked" in msg and attempt < DB_LOCK_RETRY_COUNT:
attempt += 1
delay = DB_LOCK_RETRY_BASE_DELAY * (2 ** (attempt - 1))
logger.warning(f"Database locked, retrying in {delay}s (attempt {attempt}/{DB_LOCK_RETRY_COUNT})")
time.sleep(delay)
continue
else:
logger.error(f"Failed to insert chunk vector after {attempt} retries: {e}")
# Only retry on database locked errors
if "database is locked" not in str(e).lower():
logger.error(f"Failed to insert chunk vector: {e}")
raise RuntimeError(f"Failed to INSERT chunk vector (vector_as_f32 call): {e}") from e
raise # Re-raise for retry decorator to handle
except Exception as e:
logger.error(f"Failed to insert chunk vector: {e}")
raise RuntimeError(f"Failed to INSERT chunk vector (vector_as_f32 call): {e}") from e

try:
return _insert_with_retry()
except sqlite3.OperationalError as e:
logger.error(f"Failed to insert chunk vector after {DB_LOCK_RETRY_COUNT} retries: {e}")
raise RuntimeError(f"Failed to INSERT chunk vector after retries: {e}") from e


def search_vectors(database_path: str, q_vector: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
Expand All @@ -204,14 +209,31 @@ def search_vectors(database_path: str, q_vector: List[float], top_k: int = 5) ->
Raises:
RuntimeError: If vector search operations fail
"""
from .connection import db_connection

logger.debug(f"Searching vectors in database: {database_path}, top_k={top_k}")
conn = connect_db(database_path)
try:
load_sqlite_vector_extension(conn)

with db_connection(database_path, enable_vector=True) as conn:
ensure_chunks_and_meta(conn)

q_json = json.dumps(q_vector)
# Ensure vector index is initialized before searching
cur = conn.cursor()
cur.execute("SELECT value FROM vector_meta WHERE key = 'dimension'")
row = cur.fetchone()
if not row:
# No dimension stored means no vectors have been indexed yet
logger.info("No vector dimension found in metadata - no chunks indexed yet")
return []

dim = int(row[0])
try:
conn.execute(f"SELECT vector_init('chunks', 'embedding', 'dimension={dim},type=FLOAT32,distance=COSINE')")
logger.debug(f"Vector index initialized for search with dimension {dim}")
except Exception as e:
logger.error(f"vector_init failed during search: {e}")
raise RuntimeError(f"vector_init failed during search: {e}") from e

q_json = json.dumps(q_vector)
try:
cur.execute(
"""
Expand All @@ -237,8 +259,6 @@ def search_vectors(database_path: str, q_vector: List[float], top_k: int = 5) ->
score = float(distance)
results.append({"file_id": int(file_id), "path": path, "chunk_index": int(chunk_index), "score": score})
return results
finally:
conn.close()


def get_chunk_text(database_path: str, file_id: int, chunk_index: int) -> Optional[str]:
Expand All @@ -255,9 +275,9 @@ def get_chunk_text(database_path: str, file_id: int, chunk_index: int) -> Option
The chunk text, or None if not found
"""
from .operations import get_project_metadata
from .connection import db_connection

conn = connect_db(database_path)
try:
with db_connection(database_path) as conn:
cur = conn.cursor()
# Get file path from database
cur.execute("SELECT path FROM files WHERE id = ?", (file_id,))
Expand Down Expand Up @@ -321,5 +341,3 @@ def get_chunk_text(database_path: str, file_id: int, chunk_index: int) -> Option
start = chunk_index * step
end = min(start + CHUNK_SIZE, len(content))
return content[start:end]
finally:
conn.close()
Loading