Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion sqlspec/adapters/aiosqlite/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
import time
from contextlib import suppress
from inspect import isawaitable
from threading import Thread
from typing import TYPE_CHECKING, Any

import aiosqlite
Expand Down Expand Up @@ -265,6 +267,82 @@ def _database_name(self) -> str:
db = self._connection_parameters.get("database", "unknown")
return str(db).split("/")[-1] if db else "unknown"

def _set_connect_proxy_daemon(self, connect_proxy: Any) -> None:
"""Set daemon mode on aiosqlite worker thread before await.

aiosqlite <=0.21 used Connection as a Thread subclass.
aiosqlite >=0.22 stores an internal ``_thread`` attribute instead.
"""
try:
if isinstance(connect_proxy, Thread):
connect_proxy.daemon = True
return

worker_thread = connect_proxy._thread # pyright: ignore[reportAttributeAccessIssue]
if isinstance(worker_thread, Thread):
worker_thread.daemon = True
except Exception:
log_with_context(
logger,
logging.DEBUG,
"pool.connection.daemon.configure.error",
adapter=_ADAPTER_NAME,
pool_id=self._pool_id,
database=self._database_name,
)

async def _force_stop_connection(self, connection: AiosqlitePoolConnection, *, reason: str) -> None:
"""Force-stop aiosqlite worker thread when graceful close times out."""
try:
stop_method = connection.connection.stop # pyright: ignore[reportAttributeAccessIssue]
except Exception:
log_with_context(
logger,
logging.DEBUG,
"pool.connection.force_stop.unavailable",
adapter=_ADAPTER_NAME,
pool_id=self._pool_id,
connection_id=connection.id,
reason=reason,
)
return

try:
stop_result = stop_method()
if isawaitable(stop_result):
await asyncio.wait_for(stop_result, timeout=self._operation_timeout)
log_with_context(
logger,
logging.DEBUG,
"pool.connection.force_stop.success",
adapter=_ADAPTER_NAME,
pool_id=self._pool_id,
connection_id=connection.id,
reason=reason,
)
except asyncio.TimeoutError:
log_with_context(
logger,
logging.WARNING,
"pool.connection.force_stop.timeout",
adapter=_ADAPTER_NAME,
pool_id=self._pool_id,
connection_id=connection.id,
timeout_seconds=self._operation_timeout,
reason=reason,
)
except Exception as e:
log_with_context(
logger,
logging.WARNING,
"pool.connection.force_stop.error",
adapter=_ADAPTER_NAME,
pool_id=self._pool_id,
connection_id=connection.id,
reason=reason,
error=str(e),
)

def size(self) -> int:
"""Get total number of connections in pool.

Expand All @@ -289,7 +367,9 @@ async def _create_connection(self) -> AiosqlitePoolConnection:
Returns:
New pool connection instance
"""
connection = await aiosqlite.connect(**self._connection_parameters)
connect_proxy = aiosqlite.connect(**self._connection_parameters)
self._set_connect_proxy_daemon(connect_proxy)
connection = await connect_proxy

database_path = str(self._connection_parameters.get("database", ""))
is_shared_cache = "cache=shared" in database_path
Expand Down Expand Up @@ -407,6 +487,7 @@ async def _retire_connection(self, connection: AiosqlitePoolConnection, *, reaso
connection_id=connection.id,
timeout_seconds=self._operation_timeout,
)
await self._force_stop_connection(connection, reason="retire_close_timeout")

async def _try_provision_new_connection(self) -> "AiosqlitePoolConnection | None":
"""Try to create a new connection if under capacity.
Expand Down Expand Up @@ -640,6 +721,8 @@ async def close(self) -> None:

for i, result in enumerate(results):
if isinstance(result, Exception):
if isinstance(result, asyncio.TimeoutError):
await self._force_stop_connection(connections[i], reason="pool_close_timeout")
log_with_context(
logger,
logging.WARNING,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/adapters/test_aiosqlite/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for aiosqlite adapter internals."""
128 changes: 128 additions & 0 deletions tests/unit/adapters/test_aiosqlite/test_pool_shutdown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Tests for aiosqlite pool shutdown behavior."""

# pyright: reportPrivateImportUsage = false, reportPrivateUsage = false

import asyncio
from threading import Thread
from typing import TYPE_CHECKING, Any, cast

import pytest

from sqlspec.adapters.aiosqlite.pool import AiosqliteConnectionPool, AiosqlitePoolConnection

if TYPE_CHECKING:
from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection


class _FakeAiosqliteConnection:
"""Minimal async connection stub used by pool tests."""

def __init__(self) -> None:
self.executed: list[str] = []
self.stop_called = 0

async def execute(self, sql: str) -> None:
self.executed.append(sql)

async def commit(self) -> None:
return None

async def rollback(self) -> None:
return None

async def close(self) -> None:
return None

def stop(self) -> None:
self.stop_called += 1


class _LegacyConnectProxy(Thread):
"""aiosqlite <=0.21 style connect proxy (thread subclass)."""

def __init__(self, connection: _FakeAiosqliteConnection) -> None:
super().__init__(target=lambda: None)
self._connection = connection

def __await__(self) -> Any:
async def _resolve() -> _FakeAiosqliteConnection:
return self._connection

return _resolve().__await__()


class _ModernConnectProxy:
"""aiosqlite >=0.22 style connect proxy (has internal _thread)."""

def __init__(self, connection: _FakeAiosqliteConnection) -> None:
self._thread = Thread(target=lambda: None)
self._connection = connection

def __await__(self) -> Any:
async def _resolve() -> _FakeAiosqliteConnection:
return self._connection

return _resolve().__await__()


@pytest.mark.asyncio
async def test_create_connection_sets_daemon_for_legacy_proxy(monkeypatch: pytest.MonkeyPatch) -> None:
"""Pool should set daemon mode for pre-0.22 thread-based connect proxy."""
from sqlspec.adapters.aiosqlite import pool as pool_module

fake_connection = _FakeAiosqliteConnection()
connect_proxy = _LegacyConnectProxy(fake_connection)
pool = AiosqliteConnectionPool({"database": ":memory:"})

monkeypatch.setattr(pool_module.aiosqlite, "connect", lambda **_: connect_proxy)

pool_connection = await pool._create_connection()
try:
assert connect_proxy.daemon is True
finally:
await pool._retire_connection(pool_connection, reason="test_cleanup")


@pytest.mark.asyncio
async def test_create_connection_sets_daemon_for_modern_proxy(monkeypatch: pytest.MonkeyPatch) -> None:
"""Pool should set daemon mode for 0.22+ connect proxy internal worker thread."""
from sqlspec.adapters.aiosqlite import pool as pool_module

fake_connection = _FakeAiosqliteConnection()
connect_proxy = _ModernConnectProxy(fake_connection)
pool = AiosqliteConnectionPool({"database": ":memory:"})

monkeypatch.setattr(pool_module.aiosqlite, "connect", lambda **_: connect_proxy)

pool_connection = await pool._create_connection()
try:
assert connect_proxy._thread.daemon is True
finally:
await pool._retire_connection(pool_connection, reason="test_cleanup")


@pytest.mark.asyncio
async def test_pool_close_uses_force_stop_when_close_times_out(monkeypatch: pytest.MonkeyPatch) -> None:
"""Pool should trigger force-stop fallback when graceful close times out."""
from sqlspec.adapters.aiosqlite import pool as pool_module

calls: list[tuple[str, str]] = []

async def _hanging_close(self: AiosqlitePoolConnection) -> None:
await asyncio.sleep(0.05)

async def _capture_force_stop(
self: AiosqliteConnectionPool, connection: AiosqlitePoolConnection, *, reason: str
) -> None:
calls.append((connection.id, reason))

monkeypatch.setattr(pool_module.AiosqlitePoolConnection, "close", _hanging_close)
monkeypatch.setattr(pool_module.AiosqliteConnectionPool, "_force_stop_connection", _capture_force_stop)

pool = AiosqliteConnectionPool({"database": ":memory:"}, operation_timeout=0.001)
connection = AiosqlitePoolConnection(cast("AiosqliteConnection", _FakeAiosqliteConnection()))
pool._connection_registry[connection.id] = connection

await pool.close()

assert calls == [(connection.id, "pool_close_timeout")]
Loading