diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index 82f506f039..1e9022e721 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -87,7 +87,7 @@ def wake(self) -> None: """Execute the target function soon.""" self._event = True - def update_interval(self, new_interval: int) -> None: + def update_interval(self, new_interval: float) -> None: self._interval = new_interval def skip_sleep(self) -> None: @@ -217,7 +217,7 @@ def wake(self) -> None: """Execute the target function soon.""" self._event = True - def update_interval(self, new_interval: int) -> None: + def update_interval(self, new_interval: float) -> None: self._interval = new_interval def skip_sleep(self) -> None: diff --git a/test/asynchronous/test_periodic_executor.py b/test/asynchronous/test_periodic_executor.py new file mode 100644 index 0000000000..696955d40c --- /dev/null +++ b/test/asynchronous/test_periodic_executor.py @@ -0,0 +1,184 @@ +# Copyright 2026-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for periodic_executor.py.""" + +from __future__ import annotations + +import asyncio +import sys +import threading +import time + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncUnitTest, unittest + +from pymongo.periodic_executor import AsyncPeriodicExecutor + +_IS_SYNC = False + + +class TestAsyncPeriodicExecutor(AsyncUnitTest): + def _make_executor(self, interval=30.0, min_interval=0.01, target=None, name="test"): + if target is None: + + async def target(): + return True + + executor = AsyncPeriodicExecutor( + interval=interval, min_interval=min_interval, target=target, name=name + ) + self.addAsyncCleanup(self._close_executor, executor) + return executor + + async def _close_executor(self, executor): + executor.close() + await executor.join(timeout=2) + + async def test_join_without_open_is_safe(self): + executor = self._make_executor() + try: + await executor.join(timeout=0.01) + except Exception as e: + self.fail(f"join() raised unexpected Exception {e}") + + async def test_target_returning_false_stops_executor(self): + if _IS_SYNC: + ran = threading.Event() + else: + ran = asyncio.Event() + + async def target(): + ran.set() + return False + + executor = self._make_executor(target=target) + executor.open() + await executor.join(timeout=2) + self.assertTrue(ran.is_set(), "target never ran") + + async def test_skip_sleep_flag_skips_interval(self): + call_times = [] + + async def target(): + nonlocal call_times + call_times.append(time.monotonic()) + if len(call_times) >= 2: + return False + return True + + executor = self._make_executor(interval=30.0, min_interval=0.001, target=target) + executor.skip_sleep() + executor.open() + await executor.join(timeout=3) + self.assertGreaterEqual(len(call_times), 2) + self.assertLess(call_times[1] - call_times[0], 5.0) + + async def test_wake_causes_early_run(self): + call_count = 0 + if _IS_SYNC: + woken = threading.Event() + else: + woken = asyncio.Event() + + async def target(): + nonlocal call_count + call_count += 1 + if call_count == 1: + woken.set() + return call_count < 2 + + executor = self._make_executor(interval=30.0, min_interval=0.01, target=target) + executor.open() + if _IS_SYNC: + woken.wait(timeout=2) + else: + assert isinstance(woken, asyncio.Event) + await asyncio.wait_for(woken.wait(), timeout=2) + executor.wake() + await executor.join(timeout=3) + self.assertGreaterEqual(call_count, 2) + + async def test_update_interval_changes_next_wait(self): + call_times = [] + + async def target(): + nonlocal call_times + call_times.append(time.monotonic()) + if len(call_times) == 1: + # Shorten the interval from 30s so the next run happens promptly. + executor.update_interval(0.05) + return True + return False + + executor = self._make_executor(interval=30.0, min_interval=0.01, target=target) + executor.open() + await executor.join(timeout=3) + self.assertGreaterEqual(len(call_times), 2) + self.assertLess(call_times[1] - call_times[0], 5.0) + + async def test_open_after_target_returns_false(self): + called = 0 + + async def target(): + nonlocal called + called += 1 + return False + + executor = self._make_executor(target=target) + executor.open() + await executor.join(timeout=2) + executor.open() + await executor.join(timeout=2) + self.assertGreaterEqual(called, 2) + + async def test_target_exception_stops_executor(self): + call_count = 0 + + async def target(): + nonlocal call_count + call_count += 1 + raise RuntimeError("error") + + executor = self._make_executor(target=target) + + if _IS_SYNC: + # The exception re-raises on the executor's background thread, + # which would otherwise trigger threading.excepthook and print a + # noisy traceback. Swap it for a no-op for the duration of the test. + original_excepthook = threading.excepthook + threading.excepthook = lambda args: None + self.addCleanup(setattr, threading, "excepthook", original_excepthook) + + executor.open() + await executor.join(timeout=2) + if not _IS_SYNC and executor._task is not None and executor._task.done(): + # Retrieve the exception to avoid "Task exception was never + # retrieved" warnings when the task is garbage collected. + executor._task.exception() + self.assertEqual(call_count, 1, "target should stop after raising") + + # Re-opening after an exception restarts the executor. For the threaded + # PeriodicExecutor this also exercises the _thread_will_exit join path + # in open(). + executor.open() + await executor.join(timeout=2) + if not _IS_SYNC and executor._task is not None and executor._task.done(): + executor._task.exception() + self.assertEqual(call_count, 2, "executor should run again after re-open") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_periodic_executor.py b/test/test_periodic_executor.py new file mode 100644 index 0000000000..0797e4c71e --- /dev/null +++ b/test/test_periodic_executor.py @@ -0,0 +1,184 @@ +# Copyright 2026-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for periodic_executor.py.""" + +from __future__ import annotations + +import asyncio +import sys +import threading +import time + +sys.path[0:0] = [""] + +from test import UnitTest, unittest + +from pymongo.periodic_executor import PeriodicExecutor + +_IS_SYNC = True + + +class TestPeriodicExecutor(UnitTest): + def _make_executor(self, interval=30.0, min_interval=0.01, target=None, name="test"): + if target is None: + + def target(): + return True + + executor = PeriodicExecutor( + interval=interval, min_interval=min_interval, target=target, name=name + ) + self.addCleanup(self._close_executor, executor) + return executor + + def _close_executor(self, executor): + executor.close() + executor.join(timeout=2) + + def test_join_without_open_is_safe(self): + executor = self._make_executor() + try: + executor.join(timeout=0.01) + except Exception as e: + self.fail(f"join() raised unexpected Exception {e}") + + def test_target_returning_false_stops_executor(self): + if _IS_SYNC: + ran = threading.Event() + else: + ran = asyncio.Event() + + def target(): + ran.set() + return False + + executor = self._make_executor(target=target) + executor.open() + executor.join(timeout=2) + self.assertTrue(ran.is_set(), "target never ran") + + def test_skip_sleep_flag_skips_interval(self): + call_times = [] + + def target(): + nonlocal call_times + call_times.append(time.monotonic()) + if len(call_times) >= 2: + return False + return True + + executor = self._make_executor(interval=30.0, min_interval=0.001, target=target) + executor.skip_sleep() + executor.open() + executor.join(timeout=3) + self.assertGreaterEqual(len(call_times), 2) + self.assertLess(call_times[1] - call_times[0], 5.0) + + def test_wake_causes_early_run(self): + call_count = 0 + if _IS_SYNC: + woken = threading.Event() + else: + woken = asyncio.Event() + + def target(): + nonlocal call_count + call_count += 1 + if call_count == 1: + woken.set() + return call_count < 2 + + executor = self._make_executor(interval=30.0, min_interval=0.01, target=target) + executor.open() + if _IS_SYNC: + woken.wait(timeout=2) + else: + assert isinstance(woken, asyncio.Event) + asyncio.wait_for(woken.wait(), timeout=2) + executor.wake() + executor.join(timeout=3) + self.assertGreaterEqual(call_count, 2) + + def test_update_interval_changes_next_wait(self): + call_times = [] + + def target(): + nonlocal call_times + call_times.append(time.monotonic()) + if len(call_times) == 1: + # Shorten the interval from 30s so the next run happens promptly. + executor.update_interval(0.05) + return True + return False + + executor = self._make_executor(interval=30.0, min_interval=0.01, target=target) + executor.open() + executor.join(timeout=3) + self.assertGreaterEqual(len(call_times), 2) + self.assertLess(call_times[1] - call_times[0], 5.0) + + def test_open_after_target_returns_false(self): + called = 0 + + def target(): + nonlocal called + called += 1 + return False + + executor = self._make_executor(target=target) + executor.open() + executor.join(timeout=2) + executor.open() + executor.join(timeout=2) + self.assertGreaterEqual(called, 2) + + def test_target_exception_stops_executor(self): + call_count = 0 + + def target(): + nonlocal call_count + call_count += 1 + raise RuntimeError("error") + + executor = self._make_executor(target=target) + + if _IS_SYNC: + # The exception re-raises on the executor's background thread, + # which would otherwise trigger threading.excepthook and print a + # noisy traceback. Swap it for a no-op for the duration of the test. + original_excepthook = threading.excepthook + threading.excepthook = lambda args: None + self.addCleanup(setattr, threading, "excepthook", original_excepthook) + + executor.open() + executor.join(timeout=2) + if not _IS_SYNC and executor._task is not None and executor._task.done(): + # Retrieve the exception to avoid "Task exception was never + # retrieved" warnings when the task is garbage collected. + executor._task.exception() + self.assertEqual(call_count, 1, "target should stop after raising") + + # Re-opening after an exception restarts the executor. For the threaded + # PeriodicExecutor this also exercises the _thread_will_exit join path + # in open(). + executor.open() + executor.join(timeout=2) + if not _IS_SYNC and executor._task is not None and executor._task.done(): + executor._task.exception() + self.assertEqual(call_count, 2, "executor should run again after re-open") + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/synchro.py b/tools/synchro.py index 39250ab14a..13635a054a 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -252,6 +252,7 @@ def async_only_test(f: str) -> bool: "test_monitoring.py", "test_mongos_load_balancing.py", "test_on_demand_csfle.py", + "test_periodic_executor.py", "test_pooling.py", "test_raw_bson.py", "test_read_concern.py",