diff --git a/doc/changelog.rst b/doc/changelog.rst index 571ce3b63e..f38709203c 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,6 +1,15 @@ Changelog ========= +Changes in Version 4.17.0 (2026/XX/XX) +-------------------------------------- + +PyMongo 4.17 brings a number of changes including: + +- Added the :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.bind` and :meth:`~pymongo.client_session.ClientSession.bind` methods + that allow users to bind a session to all database operations within the scope of a context manager instead of having to explicitly pass the session to each individual operation. + See for examples and more information. + Changes in Version 4.16.0 (2026/01/07) -------------------------------------- diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index a12ca1f11b..c1e5a404d2 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -139,6 +139,7 @@ import time import uuid from collections.abc import Mapping as _Mapping +from contextvars import ContextVar, Token from typing import ( TYPE_CHECKING, Any, @@ -181,6 +182,28 @@ _IS_SYNC = False +_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None) + + +class _AsyncBoundSessionContext: + """Context manager returned by AsyncClientSession.bind() that manages bound state.""" + + def __init__(self, session: AsyncClientSession, end_session: bool) -> None: + self._session = session + self._session_token: Optional[Token[AsyncClientSession]] = None + self._end_session = end_session + + async def __aenter__(self) -> AsyncClientSession: + self._session_token = _SESSION.set(self._session) # type: ignore[assignment] + return self._session + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._session_token: + _SESSION.reset(self._session_token) # type: ignore[arg-type] + self._session_token = None + if self._end_session: + await self._session.end_session() + class SessionOptions: """Options for a new :class:`AsyncClientSession`. @@ -547,6 +570,24 @@ def _check_ended(self) -> None: if self._server_session is None: raise InvalidOperation("Cannot use ended session") + def bind(self, end_session: bool = True) -> _AsyncBoundSessionContext: + """Bind this session so it is implicitly passed to all database operations within the returned context. + + .. code-block:: python + + async with client.start_session() as s: + async with s.bind(): + # session=s is passed implicitly + await client.db.collection.insert_one({"x": 1}) + + :param end_session: Whether to end the session on exiting the returned context. Defaults to True. + If set to False, :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.end_session()` must be called + once the session is no longer used. + + .. versionadded:: 4.17 + """ + return _AsyncBoundSessionContext(self, end_session) + async def __aenter__(self) -> AsyncClientSession: return self diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 7fa0983908..95f2e3746e 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -65,7 +65,7 @@ from pymongo.asynchronous import client_session, database, uri_parser from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk -from pymongo.asynchronous.client_session import _EmptyServerSession +from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext @@ -1408,7 +1408,8 @@ def start_session( def _ensure_session( self, session: Optional[AsyncClientSession] = None ) -> Optional[AsyncClientSession]: - """If provided session is None, lend a temporary session.""" + """If provided session and bound session are None, lend a temporary session.""" + session = session or self._get_bound_session() if session: return session @@ -2267,11 +2268,14 @@ async def _tmp_session( self, session: Optional[client_session.AsyncClientSession] ) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]: """If provided session is None, lend a temporary session.""" - if session is not None: - if not isinstance(session, client_session.AsyncClientSession): - raise ValueError( - f"'session' argument must be an AsyncClientSession or None, not {type(session)}" - ) + if session is not None and not isinstance(session, client_session.AsyncClientSession): + raise ValueError( + f"'session' argument must be an AsyncClientSession or None, not {type(session)}" + ) + + # Check for a bound session. If one exists, treat it as an explicitly passed session. + session = session or self._get_bound_session() + if session: # Don't call end_session. yield session return @@ -2301,6 +2305,18 @@ async def _process_response( if session is not None: session._process_response(reply) + def _get_bound_session(self) -> Optional[AsyncClientSession]: + bound_session = _SESSION.get() + if bound_session: + if bound_session.client is self: + return bound_session + else: + raise InvalidOperation( + "Only the client that created the bound session can perform operations within its context block. See for more information." + ) + else: + return None + async def server_info( self, session: Optional[client_session.AsyncClientSession] = None ) -> dict[str, Any]: diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 8755e57261..5ef18a66bd 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -139,6 +139,7 @@ import time import uuid from collections.abc import Mapping as _Mapping +from contextvars import ContextVar, Token from typing import ( TYPE_CHECKING, Any, @@ -180,6 +181,28 @@ _IS_SYNC = True +_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None) + + +class _BoundSessionContext: + """Context manager returned by ClientSession.bind() that manages bound state.""" + + def __init__(self, session: ClientSession, end_session: bool) -> None: + self._session = session + self._session_token: Optional[Token[ClientSession]] = None + self._end_session = end_session + + def __enter__(self) -> ClientSession: + self._session_token = _SESSION.set(self._session) # type: ignore[assignment] + return self._session + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._session_token: + _SESSION.reset(self._session_token) # type: ignore[arg-type] + self._session_token = None + if self._end_session: + self._session.end_session() + class SessionOptions: """Options for a new :class:`ClientSession`. @@ -546,6 +569,24 @@ def _check_ended(self) -> None: if self._server_session is None: raise InvalidOperation("Cannot use ended session") + def bind(self, end_session: bool = True) -> _BoundSessionContext: + """Bind this session so it is implicitly passed to all database operations within the returned context. + + .. code-block:: python + + with client.start_session() as s: + with s.bind(): + # session=s is passed implicitly + client.db.collection.insert_one({"x": 1}) + + :param end_session: Whether to end the session on exiting the returned context. Defaults to True. + If set to False, :meth:`~pymongo.client_session.ClientSession.end_session()` must be called + once the session is no longer used. + + .. versionadded:: 4.17 + """ + return _BoundSessionContext(self, end_session) + def __enter__(self) -> ClientSession: return self diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index badbeac09d..161a28d48d 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -108,7 +108,7 @@ from pymongo.synchronous import client_session, database, uri_parser from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk -from pymongo.synchronous.client_session import _EmptyServerSession +from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext @@ -1406,7 +1406,8 @@ def start_session( ) def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]: - """If provided session is None, lend a temporary session.""" + """If provided session and bound session are None, lend a temporary session.""" + session = session or self._get_bound_session() if session: return session @@ -2263,11 +2264,14 @@ def _tmp_session( self, session: Optional[client_session.ClientSession] ) -> Generator[Optional[client_session.ClientSession], None]: """If provided session is None, lend a temporary session.""" - if session is not None: - if not isinstance(session, client_session.ClientSession): - raise ValueError( - f"'session' argument must be a ClientSession or None, not {type(session)}" - ) + if session is not None and not isinstance(session, client_session.ClientSession): + raise ValueError( + f"'session' argument must be a ClientSession or None, not {type(session)}" + ) + + # Check for a bound session. If one exists, treat it as an explicitly passed session. + session = session or self._get_bound_session() + if session: # Don't call end_session. yield session return @@ -2295,6 +2299,18 @@ def _process_response(self, reply: Mapping[str, Any], session: Optional[ClientSe if session is not None: session._process_response(reply) + def _get_bound_session(self) -> Optional[ClientSession]: + bound_session = _SESSION.get() + if bound_session: + if bound_session.client is self: + return bound_session + else: + raise InvalidOperation( + "Only the client that created the bound session can perform operations within its context block. See for more information." + ) + else: + return None + def server_info(self, session: Optional[client_session.ClientSession] = None) -> dict[str, Any]: """Get information about the MongoDB server we're connected to. diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 19ce868c56..404a69fdee 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -189,6 +189,52 @@ async def _test_ops(self, client, *ops): f"{f.__name__} did not return implicit session to pool", ) + # Explicit bound session + for f, args, kw in ops: + async with client.start_session() as s: + async with s.bind(): + listener.reset() + s._materialize() + last_use = s._server_session.last_use + start = time.monotonic() + self.assertLessEqual(last_use, start) + # In case "f" modifies its inputs. + args = copy.copy(args) + kw = copy.copy(kw) + await f(*args, **kw) + self.assertGreaterEqual(len(listener.started_events), 1) + for event in listener.started_events: + self.assertIn( + "lsid", + event.command, + f"{f.__name__} sent no lsid with {event.command_name}", + ) + + self.assertEqual( + s.session_id, + event.command["lsid"], + f"{f.__name__} sent wrong lsid with {event.command_name}", + ) + + self.assertFalse(s.has_ended) + + self.assertTrue(s.has_ended) + with self.assertRaisesRegex(InvalidOperation, "ended session"): + async with s.bind(): + await f(*args, **kw) + + # Test a session cannot be used on another client. + async with self.client2.start_session() as s: + async with s.bind(): + # In case "f" modifies its inputs. + args = copy.copy(args) + kw = copy.copy(kw) + with self.assertRaisesRegex( + InvalidOperation, + "Only the client that created the bound session can perform operations within its context block", + ): + await f(*args, **kw) + async def test_implicit_sessions_checkout(self): # "To confirm that implicit sessions only allocate their server session after a # successful connection checkout" test from Driver Sessions Spec. @@ -825,6 +871,73 @@ async def test_session_not_copyable(self): async with client.start_session() as s: self.assertRaises(TypeError, lambda: copy.copy(s)) + async def test_nested_session_binding(self): + coll = self.client.pymongo_test.test + await coll.insert_one({"x": 1}) + + session1 = self.client.start_session() + session2 = self.client.start_session() + session1._materialize() + session2._materialize() + try: + self.listener.reset() + # Uses implicit session + await coll.find_one() + implicit_lsid = self.listener.started_events[0].command.get("lsid") + self.assertIsNotNone(implicit_lsid) + self.assertNotEqual(implicit_lsid, session1.session_id) + self.assertNotEqual(implicit_lsid, session2.session_id) + + async with session1.bind(end_session=False): + self.listener.reset() + # Uses bound session1 + await coll.find_one() + session1_lsid = self.listener.started_events[0].command.get("lsid") + self.assertEqual(session1_lsid, session1.session_id) + + async with session2.bind(end_session=False): + self.listener.reset() + # Uses bound session2 + await coll.find_one() + session2_lsid = self.listener.started_events[0].command.get("lsid") + self.assertEqual(session2_lsid, session2.session_id) + self.assertNotEqual(session2_lsid, session1.session_id) + + self.listener.reset() + # Use bound session1 again + await coll.find_one() + session1_lsid = self.listener.started_events[0].command.get("lsid") + self.assertEqual(session1_lsid, session1.session_id) + self.assertNotEqual(session1_lsid, session2.session_id) + + self.listener.reset() + # Uses implicit session + await coll.find_one() + implicit_lsid = self.listener.started_events[0].command.get("lsid") + self.assertIsNotNone(implicit_lsid) + self.assertNotEqual(implicit_lsid, session1.session_id) + self.assertNotEqual(implicit_lsid, session2.session_id) + + finally: + await session1.end_session() + await session2.end_session() + + async def test_session_binding_end_session(self): + coll = self.client.pymongo_test.test + await coll.insert_one({"x": 1}) + + async with self.client.start_session().bind() as s1: + await coll.find_one() + + self.assertTrue(s1.has_ended) + + async with self.client.start_session().bind(end_session=False) as s2: + await coll.find_one() + + self.assertFalse(s2.has_ended) + + await s2.end_session() + class TestCausalConsistency(AsyncUnitTest): listener: SessionTestListener diff --git a/test/test_session.py b/test/test_session.py index 40d0a53afb..3963f88da0 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -189,6 +189,52 @@ def _test_ops(self, client, *ops): f"{f.__name__} did not return implicit session to pool", ) + # Explicit bound session + for f, args, kw in ops: + with client.start_session() as s: + with s.bind(): + listener.reset() + s._materialize() + last_use = s._server_session.last_use + start = time.monotonic() + self.assertLessEqual(last_use, start) + # In case "f" modifies its inputs. + args = copy.copy(args) + kw = copy.copy(kw) + f(*args, **kw) + self.assertGreaterEqual(len(listener.started_events), 1) + for event in listener.started_events: + self.assertIn( + "lsid", + event.command, + f"{f.__name__} sent no lsid with {event.command_name}", + ) + + self.assertEqual( + s.session_id, + event.command["lsid"], + f"{f.__name__} sent wrong lsid with {event.command_name}", + ) + + self.assertFalse(s.has_ended) + + self.assertTrue(s.has_ended) + with self.assertRaisesRegex(InvalidOperation, "ended session"): + with s.bind(): + f(*args, **kw) + + # Test a session cannot be used on another client. + with self.client2.start_session() as s: + with s.bind(): + # In case "f" modifies its inputs. + args = copy.copy(args) + kw = copy.copy(kw) + with self.assertRaisesRegex( + InvalidOperation, + "Only the client that created the bound session can perform operations within its context block", + ): + f(*args, **kw) + def test_implicit_sessions_checkout(self): # "To confirm that implicit sessions only allocate their server session after a # successful connection checkout" test from Driver Sessions Spec. @@ -825,6 +871,73 @@ def test_session_not_copyable(self): with client.start_session() as s: self.assertRaises(TypeError, lambda: copy.copy(s)) + def test_nested_session_binding(self): + coll = self.client.pymongo_test.test + coll.insert_one({"x": 1}) + + session1 = self.client.start_session() + session2 = self.client.start_session() + session1._materialize() + session2._materialize() + try: + self.listener.reset() + # Uses implicit session + coll.find_one() + implicit_lsid = self.listener.started_events[0].command.get("lsid") + self.assertIsNotNone(implicit_lsid) + self.assertNotEqual(implicit_lsid, session1.session_id) + self.assertNotEqual(implicit_lsid, session2.session_id) + + with session1.bind(end_session=False): + self.listener.reset() + # Uses bound session1 + coll.find_one() + session1_lsid = self.listener.started_events[0].command.get("lsid") + self.assertEqual(session1_lsid, session1.session_id) + + with session2.bind(end_session=False): + self.listener.reset() + # Uses bound session2 + coll.find_one() + session2_lsid = self.listener.started_events[0].command.get("lsid") + self.assertEqual(session2_lsid, session2.session_id) + self.assertNotEqual(session2_lsid, session1.session_id) + + self.listener.reset() + # Use bound session1 again + coll.find_one() + session1_lsid = self.listener.started_events[0].command.get("lsid") + self.assertEqual(session1_lsid, session1.session_id) + self.assertNotEqual(session1_lsid, session2.session_id) + + self.listener.reset() + # Uses implicit session + coll.find_one() + implicit_lsid = self.listener.started_events[0].command.get("lsid") + self.assertIsNotNone(implicit_lsid) + self.assertNotEqual(implicit_lsid, session1.session_id) + self.assertNotEqual(implicit_lsid, session2.session_id) + + finally: + session1.end_session() + session2.end_session() + + def test_session_binding_end_session(self): + coll = self.client.pymongo_test.test + coll.insert_one({"x": 1}) + + with self.client.start_session().bind() as s1: + coll.find_one() + + self.assertTrue(s1.has_ended) + + with self.client.start_session().bind(end_session=False) as s2: + coll.find_one() + + self.assertFalse(s2.has_ended) + + s2.end_session() + class TestCausalConsistency(UnitTest): listener: SessionTestListener diff --git a/tools/synchro.py b/tools/synchro.py index 5735d0052a..ee719d7429 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -37,6 +37,7 @@ "AsyncRawBatchCursor": "RawBatchCursor", "AsyncRawBatchCommandCursor": "RawBatchCommandCursor", "AsyncClientSession": "ClientSession", + "_AsyncBoundSessionContext": "_BoundSessionContext", "AsyncChangeStream": "ChangeStream", "AsyncCollectionChangeStream": "CollectionChangeStream", "AsyncDatabaseChangeStream": "DatabaseChangeStream",