Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@
from pymongo.errors import (
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
PyMongoError,
WTimeoutError,
Expand Down Expand Up @@ -480,14 +482,20 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:
_BACKOFF_INITIAL = 0.005 # 5ms initial backoff


def _within_time_limit(start_time: float) -> bool:
def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
"""Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
return False
return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT


def _would_exceed_time_limit(start_time: float, backoff: float) -> bool:
"""Is the backoff within the with_transaction retry limit?"""
return time.monotonic() + backoff - start_time >= _WITH_TRANSACTION_RETRY_TIME_LIMIT
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50})
else:
return NetworkTimeout(str(error))


_T = TypeVar("_T")
Expand Down Expand Up @@ -722,9 +730,9 @@ async def callback(session, custom_arg, custom_kwarg=None):
if retry: # Implement exponential backoff on retry.
jitter = random.random() # noqa: S311
backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX)
if _would_exceed_time_limit(start_time, backoff):
if not _within_time_limit(start_time, backoff):
assert last_error is not None
raise last_error
raise _make_timeout_error(last_error) from last_error
await asyncio.sleep(backoff)
retry += 1
await self.start_transaction(
Expand All @@ -737,13 +745,13 @@ async def callback(session, custom_arg, custom_kwarg=None):
last_error = exc
if self.in_transaction:
await self.abort_transaction()
if (
isinstance(exc, PyMongoError)
and exc.has_error_label("TransientTransactionError")
and _within_time_limit(start_time)
if isinstance(exc, PyMongoError) and exc.has_error_label(
"TransientTransactionError"
):
# Retry the entire transaction.
continue
if _within_time_limit(start_time):
# Retry the entire transaction.
continue
raise _make_timeout_error(last_error) from exc
raise

if not self.in_transaction:
Expand All @@ -754,17 +762,16 @@ async def callback(session, custom_arg, custom_kwarg=None):
try:
await self.commit_transaction()
except PyMongoError as exc:
if (
exc.has_error_label("UnknownTransactionCommitResult")
and _within_time_limit(start_time)
and not _max_time_expired_error(exc)
):
last_error = exc
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
if exc.has_error_label(
"UnknownTransactionCommitResult"
) and not _max_time_expired_error(exc):
# Retry the commit.
continue

if exc.has_error_label("TransientTransactionError") and _within_time_limit(
start_time
):
if exc.has_error_label("TransientTransactionError"):
# Retry the entire transaction.
break
raise
Expand Down
49 changes: 28 additions & 21 deletions pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@
from pymongo.errors import (
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
PyMongoError,
WTimeoutError,
Expand Down Expand Up @@ -478,14 +480,20 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:
_BACKOFF_INITIAL = 0.005 # 5ms initial backoff


def _within_time_limit(start_time: float) -> bool:
def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
"""Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
return False
return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT


def _would_exceed_time_limit(start_time: float, backoff: float) -> bool:
"""Is the backoff within the with_transaction retry limit?"""
return time.monotonic() + backoff - start_time >= _WITH_TRANSACTION_RETRY_TIME_LIMIT
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50})
else:
return NetworkTimeout(str(error))


_T = TypeVar("_T")
Expand Down Expand Up @@ -720,9 +728,9 @@ def callback(session, custom_arg, custom_kwarg=None):
if retry: # Implement exponential backoff on retry.
jitter = random.random() # noqa: S311
backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX)
if _would_exceed_time_limit(start_time, backoff):
if not _within_time_limit(start_time, backoff):
assert last_error is not None
raise last_error
raise _make_timeout_error(last_error) from last_error
time.sleep(backoff)
retry += 1
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
Expand All @@ -733,13 +741,13 @@ def callback(session, custom_arg, custom_kwarg=None):
last_error = exc
if self.in_transaction:
self.abort_transaction()
if (
isinstance(exc, PyMongoError)
and exc.has_error_label("TransientTransactionError")
and _within_time_limit(start_time)
if isinstance(exc, PyMongoError) and exc.has_error_label(
"TransientTransactionError"
):
# Retry the entire transaction.
continue
if _within_time_limit(start_time):
# Retry the entire transaction.
continue
raise _make_timeout_error(last_error) from exc
raise

if not self.in_transaction:
Expand All @@ -750,17 +758,16 @@ def callback(session, custom_arg, custom_kwarg=None):
try:
self.commit_transaction()
except PyMongoError as exc:
if (
exc.has_error_label("UnknownTransactionCommitResult")
and _within_time_limit(start_time)
and not _max_time_expired_error(exc)
):
last_error = exc
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
if exc.has_error_label(
"UnknownTransactionCommitResult"
) and not _max_time_expired_error(exc):
# Retry the commit.
continue

if exc.has_error_label("TransientTransactionError") and _within_time_limit(
start_time
):
if exc.has_error_label("TransientTransactionError"):
# Retry the entire transaction.
break
raise
Expand Down
41 changes: 38 additions & 3 deletions test/asynchronous/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import time
from io import BytesIO

import pymongo
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
from pymongo.asynchronous.pool import PoolState
from pymongo.server_selectors import writable_server_selector
Expand All @@ -47,7 +48,9 @@
CollectionInvalid,
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
)
from pymongo.operations import IndexModel, InsertOne
Expand Down Expand Up @@ -497,7 +500,7 @@ async def callback(session):
listener.reset()
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
with self.assertRaises(NetworkTimeout):
await s.with_transaction(callback)

self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
Expand Down Expand Up @@ -531,7 +534,7 @@ async def callback(session):

async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
with self.assertRaises(NetworkTimeout):
await s.with_transaction(callback)

self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
Expand Down Expand Up @@ -562,7 +565,7 @@ async def callback(session):

async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(ConnectionFailure):
with self.assertRaises(NetworkTimeout):
await s.with_transaction(callback)

# One insert for the callback and two commits (includes the automatic
Expand All @@ -571,6 +574,38 @@ async def callback(session):
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
)

@async_client_context.require_transactions
async def test_callback_not_retried_after_csot_timeout(self):
listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test

async def callback(session):
await coll.insert_one({}, session=session)
err: dict = {
"ok": 0,
"errmsg": "Transaction 7819 has been aborted.",
"code": 251,
"codeName": "NoSuchTransaction",
"errorLabels": ["TransientTransactionError"],
}
raise OperationFailure(err["errmsg"], err["code"], err)

# Create the collection.
await coll.insert_one({})
listener.reset()
async with client.start_session() as s:
with pymongo.timeout(1.0):
with self.assertRaises(ExecutionTimeout):
await s.with_transaction(callback)

# At least two attempts: the original and one or more retries.
inserts = len([x for x in listener.started_command_names() if x == "insert"])
aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"])

self.assertGreaterEqual(inserts, 2)
self.assertGreaterEqual(aborts, 2)

# Tested here because this supports Motor's convenient transactions API.
@async_client_context.require_transactions
async def test_in_transaction_property(self):
Expand Down
41 changes: 38 additions & 3 deletions test/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import time
from io import BytesIO

import pymongo
from gridfs.synchronous.grid_file import GridFS, GridFSBucket
from pymongo.server_selectors import writable_server_selector
from pymongo.synchronous.pool import PoolState
Expand All @@ -42,7 +43,9 @@
CollectionInvalid,
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
)
from pymongo.operations import IndexModel, InsertOne
Expand Down Expand Up @@ -489,7 +492,7 @@ def callback(session):
listener.reset()
with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
with self.assertRaises(NetworkTimeout):
s.with_transaction(callback)

self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
Expand Down Expand Up @@ -521,7 +524,7 @@ def callback(session):

with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
with self.assertRaises(NetworkTimeout):
s.with_transaction(callback)

self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
Expand Down Expand Up @@ -550,7 +553,7 @@ def callback(session):

with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(ConnectionFailure):
with self.assertRaises(NetworkTimeout):
s.with_transaction(callback)

# One insert for the callback and two commits (includes the automatic
Expand All @@ -559,6 +562,38 @@ def callback(session):
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
)

@client_context.require_transactions
def test_callback_not_retried_after_csot_timeout(self):
listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test

def callback(session):
coll.insert_one({}, session=session)
err: dict = {
"ok": 0,
"errmsg": "Transaction 7819 has been aborted.",
"code": 251,
"codeName": "NoSuchTransaction",
"errorLabels": ["TransientTransactionError"],
}
raise OperationFailure(err["errmsg"], err["code"], err)

# Create the collection.
coll.insert_one({})
listener.reset()
with client.start_session() as s:
with pymongo.timeout(1.0):
with self.assertRaises(ExecutionTimeout):
s.with_transaction(callback)

# At least two attempts: the original and one or more retries.
inserts = len([x for x in listener.started_command_names() if x == "insert"])
aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"])

self.assertGreaterEqual(inserts, 2)
self.assertGreaterEqual(aborts, 2)

# Tested here because this supports Motor's convenient transactions API.
@client_context.require_transactions
def test_in_transaction_property(self):
Expand Down
Loading