From 72ba8bfd386251b43272fc69662210266e8d753d Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 4 Mar 2026 09:51:42 -0800 Subject: [PATCH 1/2] PYTHON-5716 - Clarify expected error if backoff exceeds CSOT's deadline in withTransaction --- pymongo/asynchronous/client_session.py | 49 +++++++++++++++----------- pymongo/synchronous/client_session.py | 49 +++++++++++++++----------- test/asynchronous/test_transactions.py | 41 +++++++++++++++++++-- test/test_transactions.py | 41 +++++++++++++++++++-- 4 files changed, 132 insertions(+), 48 deletions(-) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 5967651b53..c72e828849 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -163,7 +163,9 @@ from pymongo.errors import ( ConfigurationError, ConnectionFailure, + ExecutionTimeout, InvalidOperation, + NetworkTimeout, OperationFailure, PyMongoError, WTimeoutError, @@ -480,14 +482,20 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: _BACKOFF_INITIAL = 0.005 # 5ms initial backoff -def _within_time_limit(start_time: float) -> bool: +def _within_time_limit(start_time: float, backoff: float = 0) -> bool: """Are we within the with_transaction retry limit?""" - return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT + remaining = _csot.remaining() + if remaining is not None and remaining <= 0: + return False + return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT -def _would_exceed_time_limit(start_time: float, backoff: float) -> bool: - """Is the backoff within the with_transaction retry limit?""" - return time.monotonic() + backoff - start_time >= _WITH_TRANSACTION_RETRY_TIME_LIMIT +def _make_timeout_error(error: BaseException) -> PyMongoError: + """Convert error to a NetworkTimeout or ExecutionTimeout as appropriate.""" + if _csot.remaining() is not None: + return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}) + else: + return NetworkTimeout(str(error)) _T = TypeVar("_T") @@ -722,9 +730,9 @@ async def callback(session, custom_arg, custom_kwarg=None): if retry: # Implement exponential backoff on retry. jitter = random.random() # noqa: S311 backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX) - if _would_exceed_time_limit(start_time, backoff): + if not _within_time_limit(start_time, backoff): assert last_error is not None - raise last_error + raise _make_timeout_error(last_error) from last_error await asyncio.sleep(backoff) retry += 1 await self.start_transaction( @@ -737,13 +745,13 @@ async def callback(session, custom_arg, custom_kwarg=None): last_error = exc if self.in_transaction: await self.abort_transaction() - if ( - isinstance(exc, PyMongoError) - and exc.has_error_label("TransientTransactionError") - and _within_time_limit(start_time) + if isinstance(exc, PyMongoError) and exc.has_error_label( + "TransientTransactionError" ): - # Retry the entire transaction. - continue + if _within_time_limit(start_time): + # Retry the entire transaction. + continue + raise _make_timeout_error(last_error) from exc raise if not self.in_transaction: @@ -754,17 +762,16 @@ async def callback(session, custom_arg, custom_kwarg=None): try: await self.commit_transaction() except PyMongoError as exc: - if ( - exc.has_error_label("UnknownTransactionCommitResult") - and _within_time_limit(start_time) - and not _max_time_expired_error(exc) - ): + last_error = exc + if not _within_time_limit(start_time): + raise _make_timeout_error(last_error) from exc + if exc.has_error_label( + "UnknownTransactionCommitResult" + ) and not _max_time_expired_error(exc): # Retry the commit. continue - if exc.has_error_label("TransientTransactionError") and _within_time_limit( - start_time - ): + if exc.has_error_label("TransientTransactionError"): # Retry the entire transaction. break raise diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index dcda05dc46..2467bc71b3 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -160,7 +160,9 @@ from pymongo.errors import ( ConfigurationError, ConnectionFailure, + ExecutionTimeout, InvalidOperation, + NetworkTimeout, OperationFailure, PyMongoError, WTimeoutError, @@ -478,14 +480,20 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: _BACKOFF_INITIAL = 0.005 # 5ms initial backoff -def _within_time_limit(start_time: float) -> bool: +def _within_time_limit(start_time: float, backoff: float = 0) -> bool: """Are we within the with_transaction retry limit?""" - return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT + remaining = _csot.remaining() + if remaining is not None and remaining <= 0: + return False + return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT -def _would_exceed_time_limit(start_time: float, backoff: float) -> bool: - """Is the backoff within the with_transaction retry limit?""" - return time.monotonic() + backoff - start_time >= _WITH_TRANSACTION_RETRY_TIME_LIMIT +def _make_timeout_error(error: BaseException) -> PyMongoError: + """Convert error to a NetworkTimeout or ExecutionTimeout as appropriate.""" + if _csot.remaining() is not None: + return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}) + else: + return NetworkTimeout(str(error)) _T = TypeVar("_T") @@ -720,9 +728,9 @@ def callback(session, custom_arg, custom_kwarg=None): if retry: # Implement exponential backoff on retry. jitter = random.random() # noqa: S311 backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX) - if _would_exceed_time_limit(start_time, backoff): + if not _within_time_limit(start_time, backoff): assert last_error is not None - raise last_error + raise _make_timeout_error(last_error) from last_error time.sleep(backoff) retry += 1 self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) @@ -733,13 +741,13 @@ def callback(session, custom_arg, custom_kwarg=None): last_error = exc if self.in_transaction: self.abort_transaction() - if ( - isinstance(exc, PyMongoError) - and exc.has_error_label("TransientTransactionError") - and _within_time_limit(start_time) + if isinstance(exc, PyMongoError) and exc.has_error_label( + "TransientTransactionError" ): - # Retry the entire transaction. - continue + if _within_time_limit(start_time): + # Retry the entire transaction. + continue + raise _make_timeout_error(last_error) from exc raise if not self.in_transaction: @@ -750,17 +758,16 @@ def callback(session, custom_arg, custom_kwarg=None): try: self.commit_transaction() except PyMongoError as exc: - if ( - exc.has_error_label("UnknownTransactionCommitResult") - and _within_time_limit(start_time) - and not _max_time_expired_error(exc) - ): + last_error = exc + if not _within_time_limit(start_time): + raise _make_timeout_error(last_error) from exc + if exc.has_error_label( + "UnknownTransactionCommitResult" + ) and not _max_time_expired_error(exc): # Retry the commit. continue - if exc.has_error_label("TransientTransactionError") and _within_time_limit( - start_time - ): + if exc.has_error_label("TransientTransactionError"): # Retry the entire transaction. break raise diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 4e26c29618..30e8b4b2f2 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -21,6 +21,7 @@ import time from io import BytesIO +import pymongo from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket from pymongo.asynchronous.pool import PoolState from pymongo.server_selectors import writable_server_selector @@ -47,7 +48,9 @@ CollectionInvalid, ConfigurationError, ConnectionFailure, + ExecutionTimeout, InvalidOperation, + NetworkTimeout, OperationFailure, ) from pymongo.operations import IndexModel, InsertOne @@ -497,7 +500,7 @@ async def callback(session): listener.reset() async with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(OperationFailure): + with self.assertRaises(NetworkTimeout): await s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"]) @@ -531,7 +534,7 @@ async def callback(session): async with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(OperationFailure): + with self.assertRaises(NetworkTimeout): await s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"]) @@ -562,7 +565,7 @@ async def callback(session): async with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(ConnectionFailure): + with self.assertRaises(NetworkTimeout): await s.with_transaction(callback) # One insert for the callback and two commits (includes the automatic @@ -571,6 +574,38 @@ async def callback(session): listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"] ) + @async_client_context.require_transactions + async def test_callback_not_retried_after_csot_timeout(self): + listener = OvertCommandListener() + client = await self.async_rs_client(event_listeners=[listener]) + coll = client[self.db.name].test + + async def callback(session): + await coll.insert_one({}, session=session) + err: dict = { + "ok": 0, + "errmsg": "Transaction 7819 has been aborted.", + "code": 251, + "codeName": "NoSuchTransaction", + "errorLabels": ["TransientTransactionError"], + } + raise OperationFailure(err["errmsg"], err["code"], err) + + # Create the collection. + await coll.insert_one({}) + listener.reset() + async with client.start_session() as s: + with pymongo.timeout(0.1): + with self.assertRaises(ExecutionTimeout): + await s.with_transaction(callback) + + # At least two attempts: the original and one or more retries. + inserts = len([x for x in listener.started_command_names() if x == "insert"]) + aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"]) + + self.assertGreaterEqual(inserts, 2) + self.assertGreaterEqual(aborts, 2) + # Tested here because this supports Motor's convenient transactions API. @async_client_context.require_transactions async def test_in_transaction_property(self): diff --git a/test/test_transactions.py b/test/test_transactions.py index ff80745edc..16a355b758 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -21,6 +21,7 @@ import time from io import BytesIO +import pymongo from gridfs.synchronous.grid_file import GridFS, GridFSBucket from pymongo.server_selectors import writable_server_selector from pymongo.synchronous.pool import PoolState @@ -42,7 +43,9 @@ CollectionInvalid, ConfigurationError, ConnectionFailure, + ExecutionTimeout, InvalidOperation, + NetworkTimeout, OperationFailure, ) from pymongo.operations import IndexModel, InsertOne @@ -489,7 +492,7 @@ def callback(session): listener.reset() with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(OperationFailure): + with self.assertRaises(NetworkTimeout): s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"]) @@ -521,7 +524,7 @@ def callback(session): with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(OperationFailure): + with self.assertRaises(NetworkTimeout): s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"]) @@ -550,7 +553,7 @@ def callback(session): with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(ConnectionFailure): + with self.assertRaises(NetworkTimeout): s.with_transaction(callback) # One insert for the callback and two commits (includes the automatic @@ -559,6 +562,38 @@ def callback(session): listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"] ) + @client_context.require_transactions + def test_callback_not_retried_after_csot_timeout(self): + listener = OvertCommandListener() + client = self.rs_client(event_listeners=[listener]) + coll = client[self.db.name].test + + def callback(session): + coll.insert_one({}, session=session) + err: dict = { + "ok": 0, + "errmsg": "Transaction 7819 has been aborted.", + "code": 251, + "codeName": "NoSuchTransaction", + "errorLabels": ["TransientTransactionError"], + } + raise OperationFailure(err["errmsg"], err["code"], err) + + # Create the collection. + coll.insert_one({}) + listener.reset() + with client.start_session() as s: + with pymongo.timeout(0.1): + with self.assertRaises(ExecutionTimeout): + s.with_transaction(callback) + + # At least two attempts: the original and one or more retries. + inserts = len([x for x in listener.started_command_names() if x == "insert"]) + aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"]) + + self.assertGreaterEqual(inserts, 2) + self.assertGreaterEqual(aborts, 2) + # Tested here because this supports Motor's convenient transactions API. @client_context.require_transactions def test_in_transaction_property(self): From d9c6c9d073452b23350a43beeeb0ef9636adb04c Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 4 Mar 2026 10:43:03 -0800 Subject: [PATCH 2/2] Increase CSOT timeout --- test/asynchronous/test_transactions.py | 2 +- test/test_transactions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 30e8b4b2f2..95a07a743c 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -595,7 +595,7 @@ async def callback(session): await coll.insert_one({}) listener.reset() async with client.start_session() as s: - with pymongo.timeout(0.1): + with pymongo.timeout(1.0): with self.assertRaises(ExecutionTimeout): await s.with_transaction(callback) diff --git a/test/test_transactions.py b/test/test_transactions.py index 16a355b758..9e370294ef 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -583,7 +583,7 @@ def callback(session): coll.insert_one({}) listener.reset() with client.start_session() as s: - with pymongo.timeout(0.1): + with pymongo.timeout(1.0): with self.assertRaises(ExecutionTimeout): s.with_transaction(callback)