diff --git a/sqlspec/adapters/aiosqlite/pool.py b/sqlspec/adapters/aiosqlite/pool.py index bdabcc3f..fae8ccf1 100644 --- a/sqlspec/adapters/aiosqlite/pool.py +++ b/sqlspec/adapters/aiosqlite/pool.py @@ -4,6 +4,8 @@ import logging import time from contextlib import suppress +from inspect import isawaitable +from threading import Thread from typing import TYPE_CHECKING, Any import aiosqlite @@ -265,6 +267,82 @@ def _database_name(self) -> str: db = self._connection_parameters.get("database", "unknown") return str(db).split("/")[-1] if db else "unknown" + def _set_connect_proxy_daemon(self, connect_proxy: Any) -> None: + """Set daemon mode on aiosqlite worker thread before await. + + aiosqlite <=0.21 used Connection as a Thread subclass. + aiosqlite >=0.22 stores an internal ``_thread`` attribute instead. + """ + try: + if isinstance(connect_proxy, Thread): + connect_proxy.daemon = True + return + + worker_thread = connect_proxy._thread # pyright: ignore[reportAttributeAccessIssue] + if isinstance(worker_thread, Thread): + worker_thread.daemon = True + except Exception: + log_with_context( + logger, + logging.DEBUG, + "pool.connection.daemon.configure.error", + adapter=_ADAPTER_NAME, + pool_id=self._pool_id, + database=self._database_name, + ) + + async def _force_stop_connection(self, connection: AiosqlitePoolConnection, *, reason: str) -> None: + """Force-stop aiosqlite worker thread when graceful close times out.""" + try: + stop_method = connection.connection.stop # pyright: ignore[reportAttributeAccessIssue] + except Exception: + log_with_context( + logger, + logging.DEBUG, + "pool.connection.force_stop.unavailable", + adapter=_ADAPTER_NAME, + pool_id=self._pool_id, + connection_id=connection.id, + reason=reason, + ) + return + + try: + stop_result = stop_method() + if isawaitable(stop_result): + await asyncio.wait_for(stop_result, timeout=self._operation_timeout) + log_with_context( + logger, + logging.DEBUG, + "pool.connection.force_stop.success", + adapter=_ADAPTER_NAME, + pool_id=self._pool_id, + connection_id=connection.id, + reason=reason, + ) + except asyncio.TimeoutError: + log_with_context( + logger, + logging.WARNING, + "pool.connection.force_stop.timeout", + adapter=_ADAPTER_NAME, + pool_id=self._pool_id, + connection_id=connection.id, + timeout_seconds=self._operation_timeout, + reason=reason, + ) + except Exception as e: + log_with_context( + logger, + logging.WARNING, + "pool.connection.force_stop.error", + adapter=_ADAPTER_NAME, + pool_id=self._pool_id, + connection_id=connection.id, + reason=reason, + error=str(e), + ) + def size(self) -> int: """Get total number of connections in pool. @@ -289,7 +367,9 @@ async def _create_connection(self) -> AiosqlitePoolConnection: Returns: New pool connection instance """ - connection = await aiosqlite.connect(**self._connection_parameters) + connect_proxy = aiosqlite.connect(**self._connection_parameters) + self._set_connect_proxy_daemon(connect_proxy) + connection = await connect_proxy database_path = str(self._connection_parameters.get("database", "")) is_shared_cache = "cache=shared" in database_path @@ -407,6 +487,7 @@ async def _retire_connection(self, connection: AiosqlitePoolConnection, *, reaso connection_id=connection.id, timeout_seconds=self._operation_timeout, ) + await self._force_stop_connection(connection, reason="retire_close_timeout") async def _try_provision_new_connection(self) -> "AiosqlitePoolConnection | None": """Try to create a new connection if under capacity. @@ -640,6 +721,8 @@ async def close(self) -> None: for i, result in enumerate(results): if isinstance(result, Exception): + if isinstance(result, asyncio.TimeoutError): + await self._force_stop_connection(connections[i], reason="pool_close_timeout") log_with_context( logger, logging.WARNING, diff --git a/tests/unit/adapters/test_aiosqlite/__init__.py b/tests/unit/adapters/test_aiosqlite/__init__.py new file mode 100644 index 00000000..58052a5a --- /dev/null +++ b/tests/unit/adapters/test_aiosqlite/__init__.py @@ -0,0 +1 @@ +"""Unit tests for aiosqlite adapter internals.""" diff --git a/tests/unit/adapters/test_aiosqlite/test_pool_shutdown.py b/tests/unit/adapters/test_aiosqlite/test_pool_shutdown.py new file mode 100644 index 00000000..3918980e --- /dev/null +++ b/tests/unit/adapters/test_aiosqlite/test_pool_shutdown.py @@ -0,0 +1,128 @@ +"""Tests for aiosqlite pool shutdown behavior.""" + +# pyright: reportPrivateImportUsage = false, reportPrivateUsage = false + +import asyncio +from threading import Thread +from typing import TYPE_CHECKING, Any, cast + +import pytest + +from sqlspec.adapters.aiosqlite.pool import AiosqliteConnectionPool, AiosqlitePoolConnection + +if TYPE_CHECKING: + from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection + + +class _FakeAiosqliteConnection: + """Minimal async connection stub used by pool tests.""" + + def __init__(self) -> None: + self.executed: list[str] = [] + self.stop_called = 0 + + async def execute(self, sql: str) -> None: + self.executed.append(sql) + + async def commit(self) -> None: + return None + + async def rollback(self) -> None: + return None + + async def close(self) -> None: + return None + + def stop(self) -> None: + self.stop_called += 1 + + +class _LegacyConnectProxy(Thread): + """aiosqlite <=0.21 style connect proxy (thread subclass).""" + + def __init__(self, connection: _FakeAiosqliteConnection) -> None: + super().__init__(target=lambda: None) + self._connection = connection + + def __await__(self) -> Any: + async def _resolve() -> _FakeAiosqliteConnection: + return self._connection + + return _resolve().__await__() + + +class _ModernConnectProxy: + """aiosqlite >=0.22 style connect proxy (has internal _thread).""" + + def __init__(self, connection: _FakeAiosqliteConnection) -> None: + self._thread = Thread(target=lambda: None) + self._connection = connection + + def __await__(self) -> Any: + async def _resolve() -> _FakeAiosqliteConnection: + return self._connection + + return _resolve().__await__() + + +@pytest.mark.asyncio +async def test_create_connection_sets_daemon_for_legacy_proxy(monkeypatch: pytest.MonkeyPatch) -> None: + """Pool should set daemon mode for pre-0.22 thread-based connect proxy.""" + from sqlspec.adapters.aiosqlite import pool as pool_module + + fake_connection = _FakeAiosqliteConnection() + connect_proxy = _LegacyConnectProxy(fake_connection) + pool = AiosqliteConnectionPool({"database": ":memory:"}) + + monkeypatch.setattr(pool_module.aiosqlite, "connect", lambda **_: connect_proxy) + + pool_connection = await pool._create_connection() + try: + assert connect_proxy.daemon is True + finally: + await pool._retire_connection(pool_connection, reason="test_cleanup") + + +@pytest.mark.asyncio +async def test_create_connection_sets_daemon_for_modern_proxy(monkeypatch: pytest.MonkeyPatch) -> None: + """Pool should set daemon mode for 0.22+ connect proxy internal worker thread.""" + from sqlspec.adapters.aiosqlite import pool as pool_module + + fake_connection = _FakeAiosqliteConnection() + connect_proxy = _ModernConnectProxy(fake_connection) + pool = AiosqliteConnectionPool({"database": ":memory:"}) + + monkeypatch.setattr(pool_module.aiosqlite, "connect", lambda **_: connect_proxy) + + pool_connection = await pool._create_connection() + try: + assert connect_proxy._thread.daemon is True + finally: + await pool._retire_connection(pool_connection, reason="test_cleanup") + + +@pytest.mark.asyncio +async def test_pool_close_uses_force_stop_when_close_times_out(monkeypatch: pytest.MonkeyPatch) -> None: + """Pool should trigger force-stop fallback when graceful close times out.""" + from sqlspec.adapters.aiosqlite import pool as pool_module + + calls: list[tuple[str, str]] = [] + + async def _hanging_close(self: AiosqlitePoolConnection) -> None: + await asyncio.sleep(0.05) + + async def _capture_force_stop( + self: AiosqliteConnectionPool, connection: AiosqlitePoolConnection, *, reason: str + ) -> None: + calls.append((connection.id, reason)) + + monkeypatch.setattr(pool_module.AiosqlitePoolConnection, "close", _hanging_close) + monkeypatch.setattr(pool_module.AiosqliteConnectionPool, "_force_stop_connection", _capture_force_stop) + + pool = AiosqliteConnectionPool({"database": ":memory:"}, operation_timeout=0.001) + connection = AiosqlitePoolConnection(cast("AiosqliteConnection", _FakeAiosqliteConnection())) + pool._connection_registry[connection.id] = connection + + await pool.close() + + assert calls == [(connection.id, "pool_close_timeout")]