Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions db/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Unified database connection utilities.
Provides consistent connection management across all database operations.
"""
import os
import sqlite3
from typing import Optional
from contextlib import contextmanager
from utils.logger import get_logger

logger = get_logger(__name__)


def get_db_connection(
    db_path: str,
    timeout: float = 30.0,
    enable_wal: bool = True,
    enable_vector: bool = False,
    row_factory: bool = True
) -> sqlite3.Connection:
    """
    Create a database connection with consistent configuration.

    Args:
        db_path: Path to the SQLite database file
        timeout: Timeout in seconds for waiting on locks (default: 30.0)
        enable_wal: Enable Write-Ahead Logging mode (default: True)
        enable_vector: Load sqlite-vector extension (default: False)
        row_factory: Use sqlite3.Row factory for dict-like access (default: True)

    Returns:
        sqlite3.Connection object configured for the specified operations

    Raises:
        RuntimeError: If vector extension fails to load when enable_vector=True.
            The partially-configured connection is closed before re-raising,
            so no handle leaks on failure.
    """
    # Ensure the parent directory exists so sqlite3.connect() doesn't fail
    # on a first-time database path.
    dirname = os.path.dirname(os.path.abspath(db_path))
    if dirname and not os.path.isdir(dirname):
        os.makedirs(dirname, exist_ok=True)

    # check_same_thread=False: connections may be handed between threads;
    # callers must still avoid concurrent use of a single connection.
    conn = sqlite3.connect(db_path, timeout=timeout, check_same_thread=False)

    try:
        if row_factory:
            conn.row_factory = sqlite3.Row

        # WAL improves read/write concurrency; failure (e.g. unsupported
        # filesystem) is non-fatal, so log and continue with the default mode.
        if enable_wal:
            try:
                conn.execute("PRAGMA journal_mode = WAL;")
            except Exception as e:
                logger.warning(f"Failed to enable WAL mode: {e}")

        # busy_timeout is in milliseconds; mirror the connect() timeout so
        # both the driver and SQLite itself wait the same amount on locks.
        try:
            conn.execute(f"PRAGMA busy_timeout = {int(timeout * 1000)};")
        except Exception as e:
            logger.warning(f"Failed to set busy_timeout: {e}")

        # Load vector extension if requested; a failure here propagates.
        if enable_vector:
            from .vector_operations import load_sqlite_vector_extension
            load_sqlite_vector_extension(conn)
            logger.debug(f"Vector extension loaded for connection to {db_path}")
    except Exception:
        # Don't leak a half-configured connection (e.g. when the vector
        # extension fails to load): close it before propagating the error.
        conn.close()
        raise

    return conn


@contextmanager
def db_connection(db_path: str, **kwargs):
    """
    Context manager yielding a configured SQLite connection.

    The connection is created via get_db_connection() and is always closed
    when the block exits, whether normally or via an exception.

    Args:
        db_path: Path to the SQLite database file
        **kwargs: Forwarded unchanged to get_db_connection()

    Yields:
        sqlite3.Connection object

    Example:
        with db_connection(db_path) as conn:
            cur = conn.cursor()
            cur.execute("SELECT * FROM files")
            results = cur.fetchall()
    """
    connection = get_db_connection(db_path, **kwargs)
    try:
        yield connection
    finally:
        # Best-effort close: a failure here should not mask an exception
        # raised inside the caller's with-block.
        try:
            connection.close()
        except Exception as e:
            logger.warning(f"Error closing database connection: {e}")
113 changes: 45 additions & 68 deletions db/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from utils.config import CFG # config (keeps chunk_size etc if needed)
from utils.logger import get_logger
from utils.cache import project_cache, stats_cache, file_cache
from utils.retry import retry_on_db_locked
from .db_writer import get_writer

_LOG = get_logger(__name__)
Expand All @@ -15,24 +16,18 @@
_PREPARED_STATEMENTS = {}
_PREPARED_LOCK = threading.Lock()

import threading

# Simple connection helper: we open new connections per operation so the code is robust
# across threads. We set WAL journal mode for safer concurrency.
# Added a small timeout to avoid long blocking if DB is locked.
def _get_connection(db_path: str) -> sqlite3.Connection:
dirname = os.path.dirname(os.path.abspath(db_path))
if dirname and not os.path.isdir(dirname):
os.makedirs(dirname, exist_ok=True)
# timeout in seconds for busy sqlite; small value to avoid long blocking in web requests
conn = sqlite3.connect(db_path, check_same_thread=False, timeout=5.0)
conn.row_factory = sqlite3.Row
try:
conn.execute("PRAGMA journal_mode = WAL;")
except Exception:
# Not fatal — continue
pass
return conn
"""
DEPRECATED: Use db.connection.get_db_connection() instead.
This function is maintained for backward compatibility.
"""
from .connection import get_db_connection
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot I don't care about backward compatibility, remove it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the deprecated _get_connection wrapper function completely. Changes in commit 90fbbaf:

  • Removed wrapper function (lines 23-32)
  • Added direct import of get_db_connection from db.connection
  • Replaced all 18 usages with get_db_connection(path, timeout=5.0, enable_wal=True)

All database operations now use the unified connection function directly with no backward compatibility layer.

# Use shorter timeout for web requests (5s instead of default 30s)
return get_db_connection(db_path, timeout=5.0, enable_wal=True)


def _get_prepared_statement(conn: sqlite3.Connection, query_key: str, sql: str):
Expand Down Expand Up @@ -231,6 +226,8 @@ def clear_project_data(database_path: str) -> None:
cur.execute("DELETE FROM chunks")
# Delete files
cur.execute("DELETE FROM files")
# Clear vector metadata to allow re-indexing with different embedding dimensions
cur.execute("DELETE FROM vector_meta WHERE key = 'dimension'")
conn.commit()

# Invalidate caches
Expand Down Expand Up @@ -389,27 +386,6 @@ def _ensure_projects_dir():
raise


def _retry_on_db_locked(func, *args, max_retries=DB_RETRY_COUNT, **kwargs):
"""Retry a database operation if it's locked."""
import time
last_error = None

for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except sqlite3.OperationalError as e:
if "database is locked" in str(e).lower() and attempt < max_retries - 1:
last_error = e
time.sleep(DB_RETRY_DELAY * (2 ** attempt)) # Exponential backoff
continue
raise
except Exception as e:
raise

if last_error:
raise last_error


def _get_project_id(project_path: str) -> str:
"""Generate a stable project ID from the project path."""
import hashlib
Expand All @@ -428,40 +404,34 @@ def _get_projects_registry_path() -> str:
return os.path.join(PROJECTS_DIR, "registry.db")


@retry_on_db_locked(max_retries=DB_RETRY_COUNT, base_delay=DB_RETRY_DELAY)
def _init_registry_db():
"""Initialize the projects registry database with proper configuration."""
registry_path = _get_projects_registry_path()

def _init():
conn = _get_connection(registry_path)
try:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS projects (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
path TEXT NOT NULL UNIQUE,
database_path TEXT NOT NULL,
created_at TEXT DEFAULT (datetime('now')),
last_indexed_at TEXT,
status TEXT DEFAULT 'created',
settings TEXT
)
"""
)
conn.commit()
except Exception as e:
_LOG.error(f"Failed to initialize registry database: {e}")
raise
finally:
conn.close()

conn = _get_connection(registry_path)
try:
_retry_on_db_locked(_init)
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS projects (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
path TEXT NOT NULL UNIQUE,
database_path TEXT NOT NULL,
created_at TEXT DEFAULT (datetime('now')),
last_indexed_at TEXT,
status TEXT DEFAULT 'created',
settings TEXT
)
"""
)
conn.commit()
except Exception as e:
_LOG.error(f"Failed to initialize registry after retries: {e}")
_LOG.error(f"Failed to initialize registry database: {e}")
raise
finally:
conn.close()


def create_project(project_path: str, name: Optional[str] = None) -> Dict[str, Any]:
Expand Down Expand Up @@ -522,6 +492,7 @@ def create_project(project_path: str, name: Optional[str] = None) -> Dict[str, A

registry_path = _get_projects_registry_path()

@retry_on_db_locked(max_retries=DB_RETRY_COUNT, base_delay=DB_RETRY_DELAY)
def _create():
conn = _get_connection(registry_path)
try:
Expand Down Expand Up @@ -563,7 +534,7 @@ def _create():
conn.close()

try:
result = _retry_on_db_locked(_create)
result = _create()
return result
except Exception as e:
_LOG.error(f"Failed to create project: {e}")
Expand All @@ -583,6 +554,7 @@ def get_project(project_path: str) -> Optional[Dict[str, Any]]:

registry_path = _get_projects_registry_path()

@retry_on_db_locked(max_retries=DB_RETRY_COUNT, base_delay=DB_RETRY_DELAY)
def _get():
conn = _get_connection(registry_path)
try:
Expand All @@ -596,7 +568,7 @@ def _get():
finally:
conn.close()

return _retry_on_db_locked(_get)
return _get()


def get_project_by_id(project_id: str) -> Optional[Dict[str, Any]]:
Expand All @@ -611,6 +583,7 @@ def get_project_by_id(project_id: str) -> Optional[Dict[str, Any]]:

registry_path = _get_projects_registry_path()

@retry_on_db_locked(max_retries=DB_RETRY_COUNT, base_delay=DB_RETRY_DELAY)
def _get():
conn = _get_connection(registry_path)
try:
Expand All @@ -624,7 +597,7 @@ def _get():
finally:
conn.close()

return _retry_on_db_locked(_get)
return _get()


def list_projects() -> List[Dict[str, Any]]:
Expand All @@ -633,6 +606,7 @@ def list_projects() -> List[Dict[str, Any]]:

registry_path = _get_projects_registry_path()

@retry_on_db_locked(max_retries=DB_RETRY_COUNT, base_delay=DB_RETRY_DELAY)
def _list():
conn = _get_connection(registry_path)
try:
Expand All @@ -643,7 +617,7 @@ def _list():
finally:
conn.close()

return _retry_on_db_locked(_list)
return _list()


def update_project_status(project_id: str, status: str, last_indexed_at: Optional[str] = None):
Expand All @@ -652,6 +626,7 @@ def update_project_status(project_id: str, status: str, last_indexed_at: Optiona

registry_path = _get_projects_registry_path()

@retry_on_db_locked(max_retries=DB_RETRY_COUNT, base_delay=DB_RETRY_DELAY)
def _update():
conn = _get_connection(registry_path)
try:
Expand All @@ -670,7 +645,7 @@ def _update():
finally:
conn.close()

_retry_on_db_locked(_update)
_update()
# Invalidate cache after update
project_cache.invalidate(f"project:id:{project_id}")

Expand All @@ -682,6 +657,7 @@ def update_project_settings(project_id: str, settings: Dict[str, Any]):

registry_path = _get_projects_registry_path()

@retry_on_db_locked(max_retries=DB_RETRY_COUNT, base_delay=DB_RETRY_DELAY)
def _update():
conn = _get_connection(registry_path)
try:
Expand All @@ -694,7 +670,7 @@ def _update():
finally:
conn.close()

_retry_on_db_locked(_update)
_update()
# Invalidate cache after update
project_cache.invalidate(f"project:id:{project_id}")

Expand All @@ -716,6 +692,7 @@ def delete_project(project_id: str):

registry_path = _get_projects_registry_path()

@retry_on_db_locked(max_retries=DB_RETRY_COUNT, base_delay=DB_RETRY_DELAY)
def _delete():
conn = _get_connection(registry_path)
try:
Expand All @@ -725,7 +702,7 @@ def _delete():
finally:
conn.close()

_retry_on_db_locked(_delete)
_delete()
# Invalidate cache after deletion
project_cache.invalidate(f"project:id:{project_id}")
if project.get("path"):
Expand Down
Loading
Loading