diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 5712c3ef..c4e92ef8 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -75,6 +75,33 @@ jobs: steps: - uses: actions/checkout@v3 + - name: Patch version for pre-release tag + if: github.event_name == 'push' + shell: python + run: | + import re, os, pathlib + tag = os.environ["GITHUB_REF_NAME"].lstrip("v") + patterns = [ + (r"^(\d+\.\d+\.\d+)-test(\d+)$", r"\1.dev\2"), + (r"^(\d+\.\d+\.\d+)-alpha(\d+)$", r"\1a\2"), + (r"^(\d+\.\d+\.\d+)-beta(\d+)$", r"\1b\2"), + (r"^(\d+\.\d+\.\d+)-rc(\d+)$", r"\1rc\2"), + ] + version = None + for pattern, repl in patterns: + m = re.match(pattern, tag) + if m: + version = re.sub(pattern, repl, tag) + break + if version is None: + print("No pre-release suffix, keeping version as-is") + else: + print(f"Patching version to: {version}") + p = pathlib.Path("pyproject.toml") + content = p.read_text() + content = re.sub(r'^version = ".*"', f'version = "{version}"', content, count=1, flags=re.MULTILINE) + p.write_text(content) + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: @@ -207,7 +234,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Publish to PyPI - if: ${{ github.event_name == 'release' || github.event.inputs.publish_pypi == 'true' }} + if: ${{ github.event_name == 'release' || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish_pypi == 'true') }} uses: pypa/gh-action-pypi-publish@release/v1 # - name: Publish Conda package diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 331828d1..294b17c2 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -24,6 +24,7 @@ """ import argparse import asyncio +import atexit import contextvars import dataclasses import datetime @@ -113,6 +114,91 @@ async def to_thread( return await loop.run_in_executor(None, func_call) +async def _poll_cancel(cancel_event: threading.Event) -> None: + while not cancel_event.is_set(): + await asyncio.sleep(0.1) + + +async def _cancellable_run( + cancel_event: threading.Event, + coro: Any, +) -> Any: + task = asyncio.create_task(coro) + cancel_check = asyncio.create_task(_poll_cancel(cancel_event)) + done, pending = await asyncio.wait( + [task, cancel_check], return_when=asyncio.FIRST_COMPLETED, + ) + for p in pending: + p.cancel() + if cancel_check in done: + task.cancel() + raise asyncio.CancelledError() + return task.result() + + +# Each `to_thread` worker thread owns a long-lived event loop reused across +# requests, so loop-bound resources (HTTP pools, DB sessions, sockets) can +# survive between calls handled by the same thread. +_thread_local = threading.local() +_loop_registry: 'Set[asyncio.AbstractEventLoop]' = set() +_loop_registry_lock = threading.Lock() + + +def _get_thread_loop() -> asyncio.AbstractEventLoop: + """Return (creating if needed) the calling thread's persistent loop.""" + loop = getattr(_thread_local, 'loop', None) + if loop is None or loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + _thread_local.loop = loop + with _loop_registry_lock: + _loop_registry.add(loop) + return loop + + +def _run_on_thread_loop(coro: Any) -> Any: + """ + Run ``coro`` on the calling thread's persistent loop. + + The loop is never closed between calls, so loop-bound resources (e.g. + httpx keep-alive pools) survive across requests and the deferred + "Event loop is closed" errors thrown by httpx/anyio at teardown do not + occur. + + Caveat: tasks the user code spawns via ``asyncio.create_task`` and + leaves running outlive the current call too. That is the price of + keeping shared resources alive; ``cancel_event`` does not reach them. + """ + loop = _get_thread_loop() + return loop.run_until_complete(coro) + + +def _shutdown_thread_loops() -> None: + """Best-effort cleanup of all persistent worker-thread loops at exit.""" + with _loop_registry_lock: + loops = list(_loop_registry) + _loop_registry.clear() + + for loop in loops: + if loop.is_closed(): + continue + try: + # Owning thread is no longer running the loop; safe to drive + # teardown from this (exiting) thread. + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.run_until_complete(loop.shutdown_default_executor()) + except Exception: + pass + finally: + try: + loop.close() + except Exception: + pass + + +atexit.register(_shutdown_thread_loops) + + # Use negative values to indicate unsigned ints / binary data / usec time precision rowdat_1_type_map = { 'bool': ft.LONGLONG, @@ -1196,11 +1282,12 @@ async def __call__( ) func_task = asyncio.create_task( - func(cancel_event, call_timer, *inputs) - if func_info['is_async'] - else to_thread( - lambda: asyncio.run( - func(cancel_event, call_timer, *inputs), + to_thread( + lambda: _run_on_thread_loop( + _cancellable_run( + cancel_event, + func(cancel_event, call_timer, *inputs), + ), ), ), ) @@ -1219,17 +1306,21 @@ async def __call__( all_tasks, return_when=asyncio.FIRST_COMPLETED, ) + # Signal the worker before awaiting cancellation: cancelling + # func_task only flips its asyncio wrapper, not the executor + # work; only cancel_event reaches the worker loop. + if func_task in pending: + cancel_event.set() + await cancel_all_tasks(pending) for task in done: if task is disconnect_task: - cancel_event.set() raise asyncio.CancelledError( 'Function call was cancelled by client disconnect', ) elif task is timeout_task: - cancel_event.set() raise asyncio.TimeoutError( 'Function call was cancelled due to timeout', ) @@ -1292,6 +1383,7 @@ async def __call__( await send(self.error_response_dict) finally: + cancel_event.set() await cancel_all_tasks(all_tasks) # Handle api reflection diff --git a/singlestoredb/tests/test_udf_event_loop.py b/singlestoredb/tests/test_udf_event_loop.py new file mode 100644 index 00000000..bde47389 --- /dev/null +++ b/singlestoredb/tests/test_udf_event_loop.py @@ -0,0 +1,292 @@ +"""Tests for the async UDF persistent per-thread event loop.""" +import asyncio +import contextvars +import threading +import time +import unittest +from typing import Any +from typing import List + +from ..functions.ext.asgi import _cancellable_run +from ..functions.ext.asgi import _get_thread_loop +from ..functions.ext.asgi import _run_on_thread_loop +from ..functions.ext.asgi import to_thread + + +class TestUDFDispatchEdgeCases(unittest.TestCase): + """Test edge cases in the UDF dispatch stack.""" + + def test_timeout_cancels_running_function(self) -> None: + """Cancel event set from timer thread cancels a blocked coroutine.""" + cancel_event = threading.Event() + + async def long_running() -> str: + await asyncio.sleep(999) + return 'should not reach' + + def set_cancel_after_delay() -> None: + time.sleep(0.2) + cancel_event.set() + + timer = threading.Thread(target=set_cancel_after_delay) + timer.start() + + start = time.monotonic() + with self.assertRaises(asyncio.CancelledError): + _run_on_thread_loop( + _cancellable_run(cancel_event, long_running()), + ) + elapsed = time.monotonic() - start + timer.join() + # 0.2s delay + up to 0.1s poll interval + margin + self.assertLess(elapsed, 0.5) + + def test_exception_propagates_through_full_stack(self) -> None: + """User exception propagates unwrapped through the entire dispatch.""" + cancel_event = threading.Event() + + class CustomUDFError(Exception): + pass + + async def failing_udf() -> None: + raise CustomUDFError('embedding service unavailable') + + with self.assertRaises(CustomUDFError) as ctx: + _run_on_thread_loop( + _cancellable_run(cancel_event, failing_udf()), + ) + self.assertEqual(str(ctx.exception), 'embedding service unavailable') + + def test_cancel_event_detected_within_poll_interval(self) -> None: + """Cancellation is detected within one poll cycle (0.1s).""" + cancel_event = threading.Event() + + async def blocked() -> str: + await asyncio.sleep(999) + return 'unreachable' + + def set_cancel() -> None: + time.sleep(0.05) + cancel_event.set() + + timer = threading.Thread(target=set_cancel) + timer.start() + + start = time.monotonic() + with self.assertRaises(asyncio.CancelledError): + _run_on_thread_loop( + _cancellable_run(cancel_event, blocked()), + ) + elapsed = time.monotonic() - start + timer.join() + # 0.05s delay + 0.1s poll interval + margin + self.assertLess(elapsed, 0.25) + + def test_context_vars_propagate_through_to_thread(self) -> None: + """Context variables are visible inside to_thread executor.""" + test_var: contextvars.ContextVar[str] = contextvars.ContextVar( + 'test_var', + ) + test_var.set('hello_from_parent') + captured: List[str] = [] + + def read_context_var() -> str: + val = test_var.get('NOT_FOUND') + captured.append(val) + return val + + async def run_in_thread() -> str: + return await to_thread(read_context_var) + + result = _run_on_thread_loop(run_in_thread()) + self.assertEqual(result, 'hello_from_parent') + self.assertEqual(captured, ['hello_from_parent']) + + def test_concurrent_requests_isolated(self) -> None: + """Parallel executions don't share state.""" + results: List[Any] = [None, None, None] + + def run_isolated(index: int) -> None: + async def compute() -> int: + await asyncio.sleep(0.05) + return index * 10 + + results[index] = _run_on_thread_loop(compute()) + + threads = [ + threading.Thread(target=run_isolated, args=(i,)) + for i in range(3) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(results, [0, 10, 20]) + + def test_sync_function_through_async_wrapper(self) -> None: + """Synchronous function works when wrapped as async coroutine.""" + cancel_event = threading.Event() + + async def sync_as_async() -> int: + # Simulates what decorator.py's async_wrapper does for sync UDFs + return 42 + 1 + + result = _run_on_thread_loop( + _cancellable_run(cancel_event, sync_as_async()), + ) + self.assertEqual(result, 43) + + def test_cancel_event_not_set_on_success(self) -> None: + """Cancel event remains unset after successful execution.""" + cancel_event = threading.Event() + + async def quick() -> str: + return 'fast' + + result = _run_on_thread_loop( + _cancellable_run(cancel_event, quick()), + ) + self.assertEqual(result, 'fast') + self.assertFalse(cancel_event.is_set()) + + +class TestRunOnThreadLoop(unittest.TestCase): + """Test _run_on_thread_loop reuses a persistent per-thread event loop.""" + + def test_basic_coroutine(self) -> None: + async def simple() -> int: + return 42 + + self.assertEqual(_run_on_thread_loop(simple()), 42) + + def test_loop_reused_across_calls(self) -> None: + """The same loop object is reused for successive calls in a thread.""" + loops: List[asyncio.AbstractEventLoop] = [] + + async def capture_loop() -> bool: + loops.append(asyncio.get_running_loop()) + return True + + _run_on_thread_loop(capture_loop()) + _run_on_thread_loop(capture_loop()) + + self.assertIs(loops[0], loops[1]) + + def test_loop_not_closed_between_calls(self) -> None: + """The persistent loop stays open so resources survive requests.""" + captured: List[asyncio.AbstractEventLoop] = [] + + async def capture_loop() -> bool: + captured.append(asyncio.get_running_loop()) + return True + + _run_on_thread_loop(capture_loop()) + loop = captured[0] + self.assertFalse(loop.is_closed()) + + # Still usable for the next request. + _run_on_thread_loop(capture_loop()) + self.assertFalse(loop.is_closed()) + + def test_async_resource_survives_between_calls(self) -> None: + """An object bound to the loop can be reused on the next call. + + This mirrors caching e.g. an httpx.AsyncClient keyed by the loop and + reusing its connection pool on subsequent requests. + """ + clients: dict[asyncio.AbstractEventLoop, object] = {} + + async def get_or_create_client() -> int: + loop = asyncio.get_running_loop() + if loop not in clients: + clients[loop] = object() + return id(clients[loop]) + + first = _run_on_thread_loop(get_or_create_client()) + second = _run_on_thread_loop(get_or_create_client()) + + self.assertEqual(first, second) + self.assertEqual(len(clients), 1) + + def test_separate_threads_get_separate_loops(self) -> None: + """Each worker thread owns its own persistent loop.""" + loops: List[asyncio.AbstractEventLoop] = [] + lock = threading.Lock() + + def run_in_thread() -> None: + async def capture() -> bool: + with lock: + loops.append(asyncio.get_running_loop()) + return True + + _run_on_thread_loop(capture()) + + threads = [threading.Thread(target=run_in_thread) for _ in range(3)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(len(loops), 3) + self.assertEqual(len({id(loop) for loop in loops}), 3) + + def test_get_thread_loop_idempotent(self) -> None: + """_get_thread_loop returns the same loop on repeated calls.""" + def run_in_thread(out: List[asyncio.AbstractEventLoop]) -> None: + out.append(_get_thread_loop()) + out.append(_get_thread_loop()) + + out: List[asyncio.AbstractEventLoop] = [] + t = threading.Thread(target=run_in_thread, args=(out,)) + t.start() + t.join() + + self.assertIs(out[0], out[1]) + + def test_exception_propagates(self) -> None: + async def failing() -> None: + raise ValueError('test error') + + with self.assertRaises(ValueError) as ctx: + _run_on_thread_loop(failing()) + self.assertEqual(str(ctx.exception), 'test error') + + def test_cancellable_run_integration(self) -> None: + """_cancellable_run works on the persistent loop.""" + cancel_event = threading.Event() + + async def slow_func() -> str: + return 'completed' + + result = _run_on_thread_loop( + _cancellable_run(cancel_event, slow_func()), + ) + self.assertEqual(result, 'completed') + + def test_cancellation_via_event(self) -> None: + """Cancellation propagates through the persistent-loop stack.""" + cancel_event = threading.Event() + cancel_event.set() + + async def blocked_func() -> str: + await asyncio.sleep(999) + return 'should not reach' + + with self.assertRaises(asyncio.CancelledError): + _run_on_thread_loop( + _cancellable_run(cancel_event, blocked_func()), + ) + + # Loop must remain usable after a cancelled request. + async def quick() -> str: + return 'ok' + + self.assertEqual( + _run_on_thread_loop(_cancellable_run(threading.Event(), quick())), + 'ok', + ) + + +if __name__ == '__main__': + unittest.main()