Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
Changelog
=========

Changes in Version 4.17.0 (2026/XX/XX)
--------------------------------------

PyMongo 4.17 brings a number of changes including:

- Added the :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.bind` and :meth:`~pymongo.client_session.ClientSession.bind` methods
that allow users to bind a session to all database operations within the scope of a context manager instead of having to explicitly pass the session to each individual operation.
See <PLACEHOLDER> for examples and more information.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

<PLACEHOLDER> should be the MongoDB docs ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah once we have examples and such added to a page I'll update these spots.


Changes in Version 4.16.0 (2026/01/07)
--------------------------------------

Expand Down
41 changes: 41 additions & 0 deletions pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar, Token
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -181,6 +182,28 @@

_IS_SYNC = False

_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)


class _AsyncBoundSessionContext:
"""Context manager returned by AsyncClientSession.bind() that manages bound state."""

def __init__(self, session: AsyncClientSession, end_session: bool) -> None:
self._session = session
self._session_token: Optional[Token[AsyncClientSession]] = None
self._end_session = end_session

async def __aenter__(self) -> AsyncClientSession:
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
return self._session

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._session_token:
_SESSION.reset(self._session_token) # type: ignore[arg-type]
self._session_token = None
if self._end_session:
await self._session.end_session()


class SessionOptions:
"""Options for a new :class:`AsyncClientSession`.
Expand Down Expand Up @@ -547,6 +570,24 @@ def _check_ended(self) -> None:
if self._server_session is None:
raise InvalidOperation("Cannot use ended session")

def bind(self, end_session: bool = True) -> _AsyncBoundSessionContext:
"""Bind this session so it is implicitly passed to all database operations within the returned context.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a short docs explain here? We still do that right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like a .. code-block:: python example


.. code-block:: python

async with client.start_session() as s:
async with s.bind():
# session=s is passed implicitly
await client.db.collection.insert_one({"x": 1})

:param end_session: Whether to end the session on exiting the returned context. Defaults to True.
If set to False, :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.end_session()` must be called
once the session is no longer used.

.. versionadded:: 4.17
"""
return _AsyncBoundSessionContext(self, end_session)

async def __aenter__(self) -> AsyncClientSession:
return self

Expand Down
30 changes: 23 additions & 7 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from pymongo.asynchronous import client_session, database, uri_parser
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
Expand Down Expand Up @@ -1408,7 +1408,8 @@ def start_session(
def _ensure_session(
self, session: Optional[AsyncClientSession] = None
) -> Optional[AsyncClientSession]:
"""If provided session is None, lend a temporary session."""
"""If provided session and bound session are None, lend a temporary session."""
session = session or self._get_bound_session()
if session:
return session

Expand Down Expand Up @@ -2267,11 +2268,14 @@ async def _tmp_session(
self, session: Optional[client_session.AsyncClientSession]
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]:
"""If provided session is None, lend a temporary session."""
if session is not None:
if not isinstance(session, client_session.AsyncClientSession):
raise ValueError(
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
)
if session is not None and not isinstance(session, client_session.AsyncClientSession):
raise ValueError(
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
)

# Check for a bound session. If one exists, treat it as an explicitly passed session.
session = session or self._get_bound_session()
if session:
# Don't call end_session.
yield session
return
Expand Down Expand Up @@ -2301,6 +2305,18 @@ async def _process_response(
if session is not None:
session._process_response(reply)

def _get_bound_session(self) -> Optional[AsyncClientSession]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is encapsulated in a separate utility function because cursor operations call _ensure_session directly, while most other database operations call _tmp_session instead. Separating out this check keeps the clarity of the current behavioral differences with minimal code duplication.

bound_session = _SESSION.get()
if bound_session:
if bound_session.client is self:
return bound_session
else:
raise InvalidOperation(
"Only the client that created the bound session can perform operations within its context block. See <PLACEHOLDER> for more information."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another <PLACEHOLDER> here … assuming that maybe these come out after merge?

)
else:
return None

async def server_info(
self, session: Optional[client_session.AsyncClientSession] = None
) -> dict[str, Any]:
Expand Down
41 changes: 41 additions & 0 deletions pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar, Token
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -180,6 +181,28 @@

_IS_SYNC = True

_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None)


class _BoundSessionContext:
"""Context manager returned by ClientSession.bind() that manages bound state."""

def __init__(self, session: ClientSession, end_session: bool) -> None:
self._session = session
self._session_token: Optional[Token[ClientSession]] = None
self._end_session = end_session

def __enter__(self) -> ClientSession:
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
return self._session

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._session_token:
_SESSION.reset(self._session_token) # type: ignore[arg-type]
self._session_token = None
if self._end_session:
self._session.end_session()


class SessionOptions:
"""Options for a new :class:`ClientSession`.
Expand Down Expand Up @@ -546,6 +569,24 @@ def _check_ended(self) -> None:
if self._server_session is None:
raise InvalidOperation("Cannot use ended session")

def bind(self, end_session: bool = True) -> _BoundSessionContext:
"""Bind this session so it is implicitly passed to all database operations within the returned context.

.. code-block:: python

with client.start_session() as s:
with s.bind():
# session=s is passed implicitly
client.db.collection.insert_one({"x": 1})

:param end_session: Whether to end the session on exiting the returned context. Defaults to True.
If set to False, :meth:`~pymongo.client_session.ClientSession.end_session()` must be called
once the session is no longer used.

.. versionadded:: 4.17
"""
return _BoundSessionContext(self, end_session)

def __enter__(self) -> ClientSession:
return self

Expand Down
30 changes: 23 additions & 7 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
from pymongo.synchronous import client_session, database, uri_parser
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _EmptyServerSession
from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext
Expand Down Expand Up @@ -1406,7 +1406,8 @@ def start_session(
)

def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
"""If provided session is None, lend a temporary session."""
"""If provided session and bound session are None, lend a temporary session."""
session = session or self._get_bound_session()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what's going on here before the changes. How does:

if session:
    return session

lend a temporary session?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We call _ensure_session(session) in a bunch of places where session is an actual explicit session if the user passed one to the operation. The former if session: return session check is to return that explicit session if it exists instead of lending a temporary one.

if session:
return session

Expand Down Expand Up @@ -2263,11 +2264,14 @@ def _tmp_session(
self, session: Optional[client_session.ClientSession]
) -> Generator[Optional[client_session.ClientSession], None]:
"""If provided session is None, lend a temporary session."""
if session is not None:
if not isinstance(session, client_session.ClientSession):
raise ValueError(
f"'session' argument must be a ClientSession or None, not {type(session)}"
)
if session is not None and not isinstance(session, client_session.ClientSession):
raise ValueError(
f"'session' argument must be a ClientSession or None, not {type(session)}"
)

# Check for a bound session. If one exists, treat it as an explicitly passed session.
session = session or self._get_bound_session()
if session:
# Don't call end_session.
yield session
return
Expand Down Expand Up @@ -2295,6 +2299,18 @@ def _process_response(self, reply: Mapping[str, Any], session: Optional[ClientSe
if session is not None:
session._process_response(reply)

def _get_bound_session(self) -> Optional[ClientSession]:
bound_session = _SESSION.get()
if bound_session:
if bound_session.client is self:
return bound_session
else:
raise InvalidOperation(
"Only the client that created the bound session can perform operations within its context block. See <PLACEHOLDER> for more information."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Third <PLACEHOLDER> maybe synchro'd

)
else:
return None

def server_info(self, session: Optional[client_session.ClientSession] = None) -> dict[str, Any]:
"""Get information about the MongoDB server we're connected to.

Expand Down
113 changes: 113 additions & 0 deletions test/asynchronous/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,52 @@ async def _test_ops(self, client, *ops):
f"{f.__name__} did not return implicit session to pool",
)

# Explicit bound session
for f, args, kw in ops:
async with client.start_session() as s:
async with s.bind():
listener.reset()
s._materialize()
last_use = s._server_session.last_use
start = time.monotonic()
self.assertLessEqual(last_use, start)
# In case "f" modifies its inputs.
args = copy.copy(args)
kw = copy.copy(kw)
await f(*args, **kw)
self.assertGreaterEqual(len(listener.started_events), 1)
for event in listener.started_events:
self.assertIn(
"lsid",
event.command,
f"{f.__name__} sent no lsid with {event.command_name}",
)

self.assertEqual(
s.session_id,
event.command["lsid"],
f"{f.__name__} sent wrong lsid with {event.command_name}",
)

self.assertFalse(s.has_ended)

self.assertTrue(s.has_ended)
with self.assertRaisesRegex(InvalidOperation, "ended session"):
async with s.bind():
await f(*args, **kw)

# Test a session cannot be used on another client.
async with self.client2.start_session() as s:
async with s.bind():
# In case "f" modifies its inputs.
args = copy.copy(args)
kw = copy.copy(kw)
with self.assertRaisesRegex(
InvalidOperation,
"Only the client that created the bound session can perform operations within its context block",
):
await f(*args, **kw)

async def test_implicit_sessions_checkout(self):
# "To confirm that implicit sessions only allocate their server session after a
# successful connection checkout" test from Driver Sessions Spec.
Expand Down Expand Up @@ -825,6 +871,73 @@ async def test_session_not_copyable(self):
async with client.start_session() as s:
self.assertRaises(TypeError, lambda: copy.copy(s))

async def test_nested_session_binding(self):
coll = self.client.pymongo_test.test
await coll.insert_one({"x": 1})

session1 = self.client.start_session()
session2 = self.client.start_session()
session1._materialize()
session2._materialize()
try:
self.listener.reset()
# Uses implicit session
await coll.find_one()
implicit_lsid = self.listener.started_events[0].command.get("lsid")
self.assertIsNotNone(implicit_lsid)
self.assertNotEqual(implicit_lsid, session1.session_id)
self.assertNotEqual(implicit_lsid, session2.session_id)

async with session1.bind(end_session=False):
self.listener.reset()
# Uses bound session1
await coll.find_one()
session1_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session1_lsid, session1.session_id)

async with session2.bind(end_session=False):
self.listener.reset()
# Uses bound session2
await coll.find_one()
session2_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session2_lsid, session2.session_id)
self.assertNotEqual(session2_lsid, session1.session_id)

self.listener.reset()
# Use bound session1 again
await coll.find_one()
session1_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session1_lsid, session1.session_id)
self.assertNotEqual(session1_lsid, session2.session_id)

self.listener.reset()
# Uses implicit session
await coll.find_one()
implicit_lsid = self.listener.started_events[0].command.get("lsid")
self.assertIsNotNone(implicit_lsid)
self.assertNotEqual(implicit_lsid, session1.session_id)
self.assertNotEqual(implicit_lsid, session2.session_id)

finally:
await session1.end_session()
await session2.end_session()

async def test_session_binding_end_session(self):
coll = self.client.pymongo_test.test
await coll.insert_one({"x": 1})

async with self.client.start_session().bind() as s1:
await coll.find_one()

self.assertTrue(s1.has_ended)

async with self.client.start_session().bind(end_session=False) as s2:
await coll.find_one()

self.assertFalse(s2.has_ended)

await s2.end_session()


class TestCausalConsistency(AsyncUnitTest):
listener: SessionTestListener
Expand Down
Loading
Loading