From 4d10c3094e32419ca511285cf08e9c481d3cbe3c Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 4 Jun 2026 18:41:37 -0400 Subject: [PATCH 1/7] PYTHON-5676 Add command_runner.run_command; route network.command() through it Introduce pymongo/asynchronous/command_runner.py (auto-generates the sync mirror), the single async code path for command execution. run_command owns the full skeleton: STARTED/SUCCEEDED/FAILED command logging AND APM publishing together, the network round trip, $clusterTime gossip, _process_response, _check_command_response, failure conversion, and auto-encryption decryption. Route network.command() through it, removing the duplicated logging/APM/send/ receive/decrypt block. Behavior is preserved byte-for-byte (logging and APM event documents unchanged); no per-command object is allocated, so the hot path is unchanged. --- pymongo/asynchronous/command_runner.py | 254 +++++++++++++++++++++++++ pymongo/asynchronous/network.py | 175 +++-------------- pymongo/synchronous/command_runner.py | 254 +++++++++++++++++++++++++ pymongo/synchronous/network.py | 175 +++-------------- 4 files changed, 556 insertions(+), 302 deletions(-) create mode 100644 pymongo/asynchronous/command_runner.py create mode 100644 pymongo/synchronous/command_runner.py diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py new file mode 100644 index 0000000000..d70d926d84 --- /dev/null +++ b/pymongo/asynchronous/command_runner.py @@ -0,0 +1,254 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The single code path for executing a command over a connection. + +Every database operation -- standard commands, cursor ``find``/``getMore`` +operations, and (collection-level and client-level) bulk writes -- runs its +network round trip through :func:`run_command`. The function owns the entire +shared skeleton: command logging, APM event publishing, ``send``/``receive``, +``$clusterTime`` gossip, ``_process_response``, ``_check_command_response``, +failure conversion, and auto-encryption decryption. Callers supply only the +parts that vary (the encoded message and a handful of transport/output hooks). +""" +from __future__ import annotations + +import datetime +import logging +from typing import ( + TYPE_CHECKING, + Any, + Mapping, + MutableMapping, + Optional, + Sequence, + Union, + cast, +) + +from bson import _decode_all_selective +from pymongo import helpers_shared +from pymongo.errors import NotPrimaryError, OperationFailure +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _convert_exception +from pymongo.network_layer import async_receive_message, async_sendall + +if TYPE_CHECKING: + from bson import CodecOptions + from pymongo.asynchronous.client_session import AsyncClientSession + from pymongo.asynchronous.mongo_client import AsyncMongoClient + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.message import _OpMsg, _OpReply + from pymongo.monitoring import _EventListeners + from pymongo.typings import _Address, _DocumentOut, _DocumentType + +_IS_SYNC = False + + +async def run_command( + conn: AsyncConnection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[AsyncMongoClient[Any]], + session: Optional[AsyncClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + unacknowledged: bool = False, + speculative_hello: bool = False, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. + + This is the single code path for command execution. It publishes the + ``STARTED``/``SUCCEEDED``/``FAILED`` command log and APM events, performs + the network round trip, gossips ``$clusterTime``, runs + ``client._process_response`` and ``_check_command_response``, and decrypts + the reply when auto-encryption is enabled. + + :param conn: The AsyncConnection to send on. + :param cmd: The command document, used for the ``STARTED`` log event. + :param dbname: The database the command runs against. + :param request_id: The request id of the encoded message. + :param msg: The encoded OP_MSG bytes to send. + :param client: The AsyncMongoClient, for ``$clusterTime`` gossip, logging, + and decryption. ``None`` disables those steps (e.g. during handshake). + :param session: The session to update from the response. + :param listeners: The event listeners, or ``None`` to disable APM. + :param address: The (host, port) of ``conn`` for APM events. + :param start: The ``datetime`` the operation began, for duration timing. + :param codec_options: The CodecOptions used to decode the reply. + :param user_fields: Response fields decoded with the codec's TypeDecoders. + :param orig: The command document published in the ``STARTED`` APM event; + defaults to ``cmd`` (differs only when the wire command was mutated, + e.g. with a read preference or after encryption). + :param op_id: The APM operation id; defaults to ``request_id``. + :param check: Raise OperationFailure on a command error. + :param allowable_errors: Errors to ignore when ``check`` is True. + :param parse_write_concern_error: Parse the ``writeConcernError`` field. + :param unacknowledged: True for an unacknowledged write: send only and fake + an ``{"ok": 1}`` reply. + :param speculative_hello: True if the command carried speculative auth, for + APM redaction. + """ + name = next(iter(cmd)) + if orig is None: + orig = cmd + publish = listeners is not None and listeners.enabled_for_commands + + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_start( + orig, + dbname, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + ) + + try: + await async_sendall(conn.conn.get_conn, msg) + if unacknowledged: + # Unacknowledged, fake a successful command response. + reply = None + docs: list[dict[str, Any]] = [{"ok": 1}] + else: + reply = await async_receive_message(conn, request_id) + conn.more_to_come = reply.more_to_come + docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + await client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) + except Exception as exc: + duration = datetime.datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_failure( + duration, + failure, + name, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + database_name=dbname, + ) + raise + + duration = datetime.datetime.now() - start + response_doc = docs[0] + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=response_doc, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + speculative_authenticate="speculativeAuthenticate" in orig, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_success( + duration, + response_doc, + name, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + speculative_hello=speculative_hello, + database_name=dbname, + ) + + if client and client._encrypter and reply: + decrypted = await client._encrypter.decrypt(reply.raw_command_response()) + docs = cast( + "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) + ) + + return docs, reply, duration diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index 5a5dc7fa2c..ed86b4522f 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -16,7 +16,6 @@ from __future__ import annotations import datetime -import logging from typing import ( TYPE_CHECKING, Any, @@ -25,23 +24,13 @@ Optional, Sequence, Union, - cast, ) -from bson import _decode_all_selective -from pymongo import _csot, helpers_shared, message +from pymongo import _csot, message +from pymongo.asynchronous.command_runner import run_command from pymongo.compression_support import _NO_COMPRESSION -from pymongo.errors import ( - NotPrimaryError, - OperationFailure, -) -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate -from pymongo.network_layer import ( - async_receive_message, - async_sendall, -) if TYPE_CHECKING: from bson import CodecOptions @@ -52,7 +41,7 @@ from pymongo.monitoring import _EventListeners from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode - from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType + from pymongo.typings import _Address, _CollationIn, _DocumentType from pymongo.write_concern import WriteConcern _IS_SYNC = False @@ -159,140 +148,24 @@ async def command( if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=spec, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - await async_sendall(conn.conn.get_conn, msg) - if use_op_msg and unacknowledged: - # Unacknowledged, fake a successful command response. - reply = None - response_doc: _DocumentOut = {"ok": 1} - else: - reply = await async_receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields - ) - - response_doc = unpacked_docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - await client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - ) - except Exception as exc: - duration = datetime.datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = message._convert_exception(exc) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = datetime.datetime.now() - start - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=response_doc, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - speculative_authenticate="speculativeAuthenticate" in orig, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - response_doc, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - speculative_hello=speculative_hello, - database_name=dbname, - ) - - if client and client._encrypter and reply: - decrypted = await client._encrypter.decrypt(reply.raw_command_response()) - response_doc = cast( - "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] - ) - - return response_doc # type: ignore[return-value] + docs, _, _ = await run_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + unacknowledged=use_op_msg and unacknowledged, + speculative_hello=speculative_hello, + ) + return docs[0] # type: ignore[return-value] diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py new file mode 100644 index 0000000000..2e272a5eab --- /dev/null +++ b/pymongo/synchronous/command_runner.py @@ -0,0 +1,254 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The single code path for executing a command over a connection. + +Every database operation -- standard commands, cursor ``find``/``getMore`` +operations, and (collection-level and client-level) bulk writes -- runs its +network round trip through :func:`run_command`. The function owns the entire +shared skeleton: command logging, APM event publishing, ``send``/``receive``, +``$clusterTime`` gossip, ``_process_response``, ``_check_command_response``, +failure conversion, and auto-encryption decryption. Callers supply only the +parts that vary (the encoded message and a handful of transport/output hooks). +""" +from __future__ import annotations + +import datetime +import logging +from typing import ( + TYPE_CHECKING, + Any, + Mapping, + MutableMapping, + Optional, + Sequence, + Union, + cast, +) + +from bson import _decode_all_selective +from pymongo import helpers_shared +from pymongo.errors import NotPrimaryError, OperationFailure +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _convert_exception +from pymongo.network_layer import receive_message, sendall + +if TYPE_CHECKING: + from bson import CodecOptions + from pymongo.message import _OpMsg, _OpReply + from pymongo.monitoring import _EventListeners + from pymongo.synchronous.client_session import ClientSession + from pymongo.synchronous.mongo_client import MongoClient + from pymongo.synchronous.pool import Connection + from pymongo.typings import _Address, _DocumentOut, _DocumentType + +_IS_SYNC = True + + +def run_command( + conn: Connection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[MongoClient[Any]], + session: Optional[ClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + unacknowledged: bool = False, + speculative_hello: bool = False, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. + + This is the single code path for command execution. It publishes the + ``STARTED``/``SUCCEEDED``/``FAILED`` command log and APM events, performs + the network round trip, gossips ``$clusterTime``, runs + ``client._process_response`` and ``_check_command_response``, and decrypts + the reply when auto-encryption is enabled. + + :param conn: The Connection to send on. + :param cmd: The command document, used for the ``STARTED`` log event. + :param dbname: The database the command runs against. + :param request_id: The request id of the encoded message. + :param msg: The encoded OP_MSG bytes to send. + :param client: The MongoClient, for ``$clusterTime`` gossip, logging, + and decryption. ``None`` disables those steps (e.g. during handshake). + :param session: The session to update from the response. + :param listeners: The event listeners, or ``None`` to disable APM. + :param address: The (host, port) of ``conn`` for APM events. + :param start: The ``datetime`` the operation began, for duration timing. + :param codec_options: The CodecOptions used to decode the reply. + :param user_fields: Response fields decoded with the codec's TypeDecoders. + :param orig: The command document published in the ``STARTED`` APM event; + defaults to ``cmd`` (differs only when the wire command was mutated, + e.g. with a read preference or after encryption). + :param op_id: The APM operation id; defaults to ``request_id``. + :param check: Raise OperationFailure on a command error. + :param allowable_errors: Errors to ignore when ``check`` is True. + :param parse_write_concern_error: Parse the ``writeConcernError`` field. + :param unacknowledged: True for an unacknowledged write: send only and fake + an ``{"ok": 1}`` reply. + :param speculative_hello: True if the command carried speculative auth, for + APM redaction. + """ + name = next(iter(cmd)) + if orig is None: + orig = cmd + publish = listeners is not None and listeners.enabled_for_commands + + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_start( + orig, + dbname, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + ) + + try: + sendall(conn.conn.get_conn, msg) + if unacknowledged: + # Unacknowledged, fake a successful command response. + reply = None + docs: list[dict[str, Any]] = [{"ok": 1}] + else: + reply = receive_message(conn, request_id) + conn.more_to_come = reply.more_to_come + docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) + except Exception as exc: + duration = datetime.datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_failure( + duration, + failure, + name, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + database_name=dbname, + ) + raise + + duration = datetime.datetime.now() - start + response_doc = docs[0] + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=response_doc, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + speculative_authenticate="speculativeAuthenticate" in orig, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_success( + duration, + response_doc, + name, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + speculative_hello=speculative_hello, + database_name=dbname, + ) + + if client and client._encrypter and reply: + decrypted = client._encrypter.decrypt(reply.raw_command_response()) + docs = cast( + "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) + ) + + return docs, reply, duration diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index 7d9bca4d58..6576f1c5e6 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -16,7 +16,6 @@ from __future__ import annotations import datetime -import logging from typing import ( TYPE_CHECKING, Any, @@ -25,23 +24,13 @@ Optional, Sequence, Union, - cast, ) -from bson import _decode_all_selective -from pymongo import _csot, helpers_shared, message +from pymongo import _csot, message from pymongo.compression_support import _NO_COMPRESSION -from pymongo.errors import ( - NotPrimaryError, - OperationFailure, -) -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate -from pymongo.network_layer import ( - receive_message, - sendall, -) +from pymongo.synchronous.command_runner import run_command if TYPE_CHECKING: from bson import CodecOptions @@ -52,7 +41,7 @@ from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection - from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType + from pymongo.typings import _Address, _CollationIn, _DocumentType from pymongo.write_concern import WriteConcern _IS_SYNC = True @@ -159,140 +148,24 @@ def command( if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=spec, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - sendall(conn.conn.get_conn, msg) - if use_op_msg and unacknowledged: - # Unacknowledged, fake a successful command response. - reply = None - response_doc: _DocumentOut = {"ok": 1} - else: - reply = receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields - ) - - response_doc = unpacked_docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - ) - except Exception as exc: - duration = datetime.datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = message._convert_exception(exc) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = datetime.datetime.now() - start - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=response_doc, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - speculative_authenticate="speculativeAuthenticate" in orig, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - response_doc, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - speculative_hello=speculative_hello, - database_name=dbname, - ) - - if client and client._encrypter and reply: - decrypted = client._encrypter.decrypt(reply.raw_command_response()) - response_doc = cast( - "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] - ) - - return response_doc # type: ignore[return-value] + docs, _, _ = run_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + unacknowledged=use_op_msg and unacknowledged, + speculative_hello=speculative_hello, + ) + return docs[0] # type: ignore[return-value] From d169dcf89e3445472702c09341b29db1b7ba30c2 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 4 Jun 2026 18:55:19 -0400 Subject: [PATCH 2/7] PYTHON-5676 Route Server.run_operation() through run_command Extend run_command with the cursor transport (conn.send_message/receive_message, exhaust more_to_come receive-only) and output hooks (unpack_res, cursor_id, is_command_response for legacy OP_QUERY, pool_opts, command_name, ensure_db for $db gossip, and a reply_doc_builder for the find/getMore/explain APM reply format). run_operation now builds the message, supplies the reply-doc builder, and keeps the Response/PinnedResponse wrapping; everything between is the shared run_command path. The legacy OP_QUERY response shaping is preserved (is_command_response=use_cmd), not deleted -- that dead-code cleanup stays out of this consolidation. Behavior (logging, APM events, exhaust/pinning, decryption) is unchanged. --- pymongo/asynchronous/command_runner.py | 134 ++++++++++++----- pymongo/asynchronous/server.py | 190 ++++++------------------- pymongo/synchronous/command_runner.py | 134 ++++++++++++----- pymongo/synchronous/server.py | 190 ++++++------------------- 4 files changed, 294 insertions(+), 354 deletions(-) diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index d70d926d84..62529f42bc 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -14,13 +14,13 @@ """The single code path for executing a command over a connection. -Every database operation -- standard commands, cursor ``find``/``getMore`` -operations, and (collection-level and client-level) bulk writes -- runs its -network round trip through :func:`run_command`. The function owns the entire -shared skeleton: command logging, APM event publishing, ``send``/``receive``, -``$clusterTime`` gossip, ``_process_response``, ``_check_command_response``, -failure conversion, and auto-encryption decryption. Callers supply only the -parts that vary (the encoded message and a handful of transport/output hooks). +Every database operation -- standard commands and cursor ``find``/``getMore`` +operations -- runs its network round trip through :func:`run_command`. The +function owns the entire shared skeleton: command logging, APM event +publishing, ``send``/``receive``, ``$clusterTime`` gossip, +``_process_response``, ``_check_command_response``, failure conversion, and +auto-encryption decryption. Callers supply only the parts that vary (the +encoded message and a handful of transport/output hooks). """ from __future__ import annotations @@ -29,6 +29,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Mapping, MutableMapping, Optional, @@ -51,6 +52,7 @@ from pymongo.asynchronous.pool import AsyncConnection from pymongo.message import _OpMsg, _OpReply from pymongo.monitoring import _EventListeners + from pymongo.pool_options import PoolOptions from pymongo.typings import _Address, _DocumentOut, _DocumentType _IS_SYNC = False @@ -72,11 +74,24 @@ async def run_command( user_fields: Optional[Mapping[str, Any]] = None, orig: Optional[MutableMapping[str, Any]] = None, op_id: Optional[int] = None, + command_name: Optional[str] = None, check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, parse_write_concern_error: bool = False, + pool_opts: Optional[PoolOptions] = None, unacknowledged: bool = False, speculative_hello: bool = False, + ensure_db: bool = False, + use_conn_transport: bool = False, + max_doc_size: int = 0, + more_to_come: bool = False, + set_conn_more_to_come: bool = True, + is_command_response: bool = True, + unpack_res: Optional[Callable[..., Any]] = None, + cursor_id: Optional[int] = None, + reply_doc_builder: Optional[ + Callable[[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]]], _DocumentOut] + ] = None, ) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. @@ -87,10 +102,11 @@ async def run_command( the reply when auto-encryption is enabled. :param conn: The AsyncConnection to send on. - :param cmd: The command document, used for the ``STARTED`` log event. + :param cmd: The command document, used for the ``STARTED`` log/APM event. :param dbname: The database the command runs against. - :param request_id: The request id of the encoded message. - :param msg: The encoded OP_MSG bytes to send. + :param request_id: The request id of the encoded message (``0`` when + ``more_to_come`` and no message is sent). + :param msg: The encoded bytes to send (ignored when ``more_to_come``). :param client: The AsyncMongoClient, for ``$clusterTime`` gossip, logging, and decryption. ``None`` disables those steps (e.g. during handshake). :param session: The session to update from the response. @@ -103,15 +119,40 @@ async def run_command( defaults to ``cmd`` (differs only when the wire command was mutated, e.g. with a read preference or after encryption). :param op_id: The APM operation id; defaults to ``request_id``. + :param command_name: The command name for the ``SUCCEEDED``/``FAILED`` APM + events; defaults to the first key of ``cmd``. :param check: Raise OperationFailure on a command error. :param allowable_errors: Errors to ignore when ``check`` is True. :param parse_write_concern_error: Parse the ``writeConcernError`` field. + :param pool_opts: PoolOptions forwarded to ``_check_command_response`` (the + cursor path uses this in place of ``allowable_errors``). :param unacknowledged: True for an unacknowledged write: send only and fake an ``{"ok": 1}`` reply. :param speculative_hello: True if the command carried speculative auth, for APM redaction. + :param ensure_db: Add ``$db`` to the published command if missing (cursor + path), after the ``STARTED`` log has been emitted. + :param use_conn_transport: Send/receive via ``conn.send_message`` / + ``conn.receive_message`` (cursor path) instead of the raw + ``async_sendall`` / ``async_receive_message`` (network path). + :param max_doc_size: The largest document size, for ``conn.send_message``. + :param more_to_come: Receive only, without sending (exhaust ``getMore``). + :param set_conn_more_to_come: Store ``reply.more_to_come`` on ``conn`` (the + network/streaming-monitor path); the cursor path manages exhaust + separately and must leave ``conn.more_to_come`` untouched. + :param is_command_response: True if the reply is an OP_MSG command response + (``_process_response``/``_check_command_response``/decryption apply); + False for a legacy OP_QUERY cursor response. + :param unpack_res: A callable decoding the wire response (cursor path); when + ``None`` the reply's own ``unpack_response`` is used. + :param cursor_id: The cursor id passed to ``unpack_res``. + :param reply_doc_builder: Builds the reply document published in the + ``SUCCEEDED`` event from ``(docs, reply)`` (cursor find/getMore format); + when ``None`` the first decoded document is published. """ name = next(iter(cmd)) + if command_name is None: + command_name = name if orig is None: orig = cmd publish = listeners is not None and listeners.enabled_for_commands @@ -135,6 +176,8 @@ async def run_command( if publish: assert listeners is not None assert address is not None + if ensure_db and "$db" not in orig: + orig["$db"] = dbname listeners.publish_command_start( orig, dbname, @@ -145,30 +188,51 @@ async def run_command( service_id=conn.service_id, ) + reply: Optional[Union[_OpReply, _OpMsg]] try: - await async_sendall(conn.conn.get_conn, msg) - if unacknowledged: + if more_to_come: + reply = await conn.receive_message(None) + elif use_conn_transport: + await conn.send_message(msg, max_doc_size) + reply = await conn.receive_message(request_id) + elif unacknowledged: + await async_sendall(conn.conn.get_conn, msg) # Unacknowledged, fake a successful command response. reply = None docs: list[dict[str, Any]] = [{"ok": 1}] else: + await async_sendall(conn.conn.get_conn, msg) reply = await async_receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) - response_doc = docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - await client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, + + if reply is not None: + if set_conn_more_to_come: + conn.more_to_come = reply.more_to_come + if unpack_res is not None: + docs = unpack_res( + reply, + cursor_id, + codec_options, + legacy_response=not is_command_response, + user_fields=user_fields, ) + else: + docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + if is_command_response: + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + await client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + pool_opts=pool_opts, + ) except Exception as exc: duration = datetime.datetime.now() - start if isinstance(exc, (NotPrimaryError, OperationFailure)): @@ -199,7 +263,7 @@ async def run_command( listeners.publish_command_failure( duration, failure, - name, + command_name, request_id, address, conn.server_connection_id, @@ -210,14 +274,18 @@ async def run_command( raise duration = datetime.datetime.now() - start - response_doc = docs[0] + published_reply: _DocumentOut + if reply_doc_builder is not None: + published_reply = reply_doc_builder(docs, reply) + else: + published_reply = docs[0] if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, message=_CommandStatusMessage.SUCCEEDED, clientId=client._topology_settings._topology_id, durationMS=duration, - reply=response_doc, + reply=published_reply, commandName=name, databaseName=dbname, requestId=request_id, @@ -234,8 +302,8 @@ async def run_command( assert address is not None listeners.publish_command_success( duration, - response_doc, - name, + published_reply, + command_name, request_id, address, conn.server_connection_id, @@ -245,7 +313,7 @@ async def run_command( database_name=dbname, ) - if client and client._encrypter and reply: + if client and client._encrypter and reply and is_command_response: decrypted = await client._encrypter.decrypt(reply.raw_command_response()) docs = cast( "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index f212306174..b18cf56c52 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -26,18 +26,14 @@ Union, ) -from bson import _decode_all_selective +from pymongo.asynchronous.command_runner import run_command from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers_shared import _check_command_response from pymongo.logger import ( - _COMMAND_LOGGER, _SDAM_LOGGER, - _CommandStatusMessage, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query +from pymongo.message import _GetMore, _OpMsg, _OpReply, _Query from pymongo.response import PinnedResponse, Response if TYPE_CHECKING: @@ -158,7 +154,6 @@ async def run_operation( :param client: An AsyncMongoClient instance. """ assert listeners is not None - publish = listeners.enabled_for_commands start = datetime.now() use_cmd = operation.use_command(conn) @@ -166,151 +161,58 @@ async def run_operation( cmd, dbn = await self.operation_to_command(operation, conn, use_cmd) if more_to_come: request_id = 0 + data = b"" + max_doc_size = 0 else: message = operation.get_message(read_preference, conn, use_cmd) request_id, data, max_doc_size = self._split_message(message) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) + user_fields = _CURSOR_DOC_FIELDS if use_cmd else None - if publish: - if "$db" not in cmd: - cmd["$db"] = dbn - assert listeners is not None - listeners.publish_command_start( - cmd, - dbn, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - if more_to_come: - reply = await conn.receive_message(None) - else: - await conn.send_message(data, max_doc_size) - reply = await conn.receive_message(request_id) - - # Unpack and check for command errors. - if use_cmd: - user_fields = _CURSOR_DOC_FIELDS - legacy_response = False - else: - user_fields = None - legacy_response = True - docs = unpack_res( - reply, - operation.cursor_id, - operation.codec_options, - legacy_response=legacy_response, - user_fields=user_fields, - ) + def _build_reply_doc( + docs: list[dict[str, Any]], reply: Optional[Union[_OpReply, _OpMsg]] + ) -> _DocumentOut: + # Must publish in find / getMore / explain command response format. if use_cmd: - first = docs[0] - await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] - _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type] - except Exception as exc: - duration = datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - listeners.publish_command_failure( - duration, - failure, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - raise - duration = datetime.now() - start - # Must publish in find / getMore / explain command response - # format. - if use_cmd: - res = docs[0] - elif operation.name == "explain": - res = docs[0] if docs else {} - else: - res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr] + return docs[0] + elif operation.name == "explain": + return docs[0] if docs else {} + res: dict[str, Any] = { + "cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, # type: ignore[union-attr] + "ok": 1, + } if operation.name == "find": res["cursor"]["firstBatch"] = docs else: res["cursor"]["nextBatch"] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=res, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - listeners.publish_command_success( - duration, - res, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - - # Decrypt response. - client = operation.client # type: ignore[assignment] - if client and client._encrypter: - if use_cmd: - decrypted = await client._encrypter.decrypt(reply.raw_command_response()) - docs = _decode_all_selective(decrypted, operation.codec_options, user_fields) + return res + + docs, reply, duration = await run_command( + conn, + cmd, + dbn, + request_id, + data, + client=client, + session=operation.session, # type: ignore[arg-type] + listeners=listeners, + address=conn.address, + start=start, + codec_options=operation.codec_options, + user_fields=user_fields, + command_name=operation.name, + pool_opts=conn.opts, + ensure_db=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + more_to_come=bool(more_to_come), + set_conn_more_to_come=False, + is_command_response=use_cmd, + unpack_res=unpack_res, + cursor_id=operation.cursor_id, + reply_doc_builder=_build_reply_doc, + ) + assert reply is not None response: Response @@ -332,7 +234,7 @@ async def run_operation( duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs, + docs=docs, # type: ignore[arg-type] more_to_come=more_to_come, ) else: @@ -342,7 +244,7 @@ async def run_operation( duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs, + docs=docs, # type: ignore[arg-type] ) return response diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index 2e272a5eab..0396326e7c 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -14,13 +14,13 @@ """The single code path for executing a command over a connection. -Every database operation -- standard commands, cursor ``find``/``getMore`` -operations, and (collection-level and client-level) bulk writes -- runs its -network round trip through :func:`run_command`. The function owns the entire -shared skeleton: command logging, APM event publishing, ``send``/``receive``, -``$clusterTime`` gossip, ``_process_response``, ``_check_command_response``, -failure conversion, and auto-encryption decryption. Callers supply only the -parts that vary (the encoded message and a handful of transport/output hooks). +Every database operation -- standard commands and cursor ``find``/``getMore`` +operations -- runs its network round trip through :func:`run_command`. The +function owns the entire shared skeleton: command logging, APM event +publishing, ``send``/``receive``, ``$clusterTime`` gossip, +``_process_response``, ``_check_command_response``, failure conversion, and +auto-encryption decryption. Callers supply only the parts that vary (the +encoded message and a handful of transport/output hooks). """ from __future__ import annotations @@ -29,6 +29,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Mapping, MutableMapping, Optional, @@ -48,6 +49,7 @@ from bson import CodecOptions from pymongo.message import _OpMsg, _OpReply from pymongo.monitoring import _EventListeners + from pymongo.pool_options import PoolOptions from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection @@ -72,11 +74,24 @@ def run_command( user_fields: Optional[Mapping[str, Any]] = None, orig: Optional[MutableMapping[str, Any]] = None, op_id: Optional[int] = None, + command_name: Optional[str] = None, check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, parse_write_concern_error: bool = False, + pool_opts: Optional[PoolOptions] = None, unacknowledged: bool = False, speculative_hello: bool = False, + ensure_db: bool = False, + use_conn_transport: bool = False, + max_doc_size: int = 0, + more_to_come: bool = False, + set_conn_more_to_come: bool = True, + is_command_response: bool = True, + unpack_res: Optional[Callable[..., Any]] = None, + cursor_id: Optional[int] = None, + reply_doc_builder: Optional[ + Callable[[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]]], _DocumentOut] + ] = None, ) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. @@ -87,10 +102,11 @@ def run_command( the reply when auto-encryption is enabled. :param conn: The Connection to send on. - :param cmd: The command document, used for the ``STARTED`` log event. + :param cmd: The command document, used for the ``STARTED`` log/APM event. :param dbname: The database the command runs against. - :param request_id: The request id of the encoded message. - :param msg: The encoded OP_MSG bytes to send. + :param request_id: The request id of the encoded message (``0`` when + ``more_to_come`` and no message is sent). + :param msg: The encoded bytes to send (ignored when ``more_to_come``). :param client: The MongoClient, for ``$clusterTime`` gossip, logging, and decryption. ``None`` disables those steps (e.g. during handshake). :param session: The session to update from the response. @@ -103,15 +119,40 @@ def run_command( defaults to ``cmd`` (differs only when the wire command was mutated, e.g. with a read preference or after encryption). :param op_id: The APM operation id; defaults to ``request_id``. + :param command_name: The command name for the ``SUCCEEDED``/``FAILED`` APM + events; defaults to the first key of ``cmd``. :param check: Raise OperationFailure on a command error. :param allowable_errors: Errors to ignore when ``check`` is True. :param parse_write_concern_error: Parse the ``writeConcernError`` field. + :param pool_opts: PoolOptions forwarded to ``_check_command_response`` (the + cursor path uses this in place of ``allowable_errors``). :param unacknowledged: True for an unacknowledged write: send only and fake an ``{"ok": 1}`` reply. :param speculative_hello: True if the command carried speculative auth, for APM redaction. + :param ensure_db: Add ``$db`` to the published command if missing (cursor + path), after the ``STARTED`` log has been emitted. + :param use_conn_transport: Send/receive via ``conn.send_message`` / + ``conn.receive_message`` (cursor path) instead of the raw + ``sendall`` / ``receive_message`` (network path). + :param max_doc_size: The largest document size, for ``conn.send_message``. + :param more_to_come: Receive only, without sending (exhaust ``getMore``). + :param set_conn_more_to_come: Store ``reply.more_to_come`` on ``conn`` (the + network/streaming-monitor path); the cursor path manages exhaust + separately and must leave ``conn.more_to_come`` untouched. + :param is_command_response: True if the reply is an OP_MSG command response + (``_process_response``/``_check_command_response``/decryption apply); + False for a legacy OP_QUERY cursor response. + :param unpack_res: A callable decoding the wire response (cursor path); when + ``None`` the reply's own ``unpack_response`` is used. + :param cursor_id: The cursor id passed to ``unpack_res``. + :param reply_doc_builder: Builds the reply document published in the + ``SUCCEEDED`` event from ``(docs, reply)`` (cursor find/getMore format); + when ``None`` the first decoded document is published. """ name = next(iter(cmd)) + if command_name is None: + command_name = name if orig is None: orig = cmd publish = listeners is not None and listeners.enabled_for_commands @@ -135,6 +176,8 @@ def run_command( if publish: assert listeners is not None assert address is not None + if ensure_db and "$db" not in orig: + orig["$db"] = dbname listeners.publish_command_start( orig, dbname, @@ -145,30 +188,51 @@ def run_command( service_id=conn.service_id, ) + reply: Optional[Union[_OpReply, _OpMsg]] try: - sendall(conn.conn.get_conn, msg) - if unacknowledged: + if more_to_come: + reply = conn.receive_message(None) + elif use_conn_transport: + conn.send_message(msg, max_doc_size) + reply = conn.receive_message(request_id) + elif unacknowledged: + sendall(conn.conn.get_conn, msg) # Unacknowledged, fake a successful command response. reply = None docs: list[dict[str, Any]] = [{"ok": 1}] else: + sendall(conn.conn.get_conn, msg) reply = receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) - response_doc = docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, + + if reply is not None: + if set_conn_more_to_come: + conn.more_to_come = reply.more_to_come + if unpack_res is not None: + docs = unpack_res( + reply, + cursor_id, + codec_options, + legacy_response=not is_command_response, + user_fields=user_fields, ) + else: + docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + if is_command_response: + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + pool_opts=pool_opts, + ) except Exception as exc: duration = datetime.datetime.now() - start if isinstance(exc, (NotPrimaryError, OperationFailure)): @@ -199,7 +263,7 @@ def run_command( listeners.publish_command_failure( duration, failure, - name, + command_name, request_id, address, conn.server_connection_id, @@ -210,14 +274,18 @@ def run_command( raise duration = datetime.datetime.now() - start - response_doc = docs[0] + published_reply: _DocumentOut + if reply_doc_builder is not None: + published_reply = reply_doc_builder(docs, reply) + else: + published_reply = docs[0] if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, message=_CommandStatusMessage.SUCCEEDED, clientId=client._topology_settings._topology_id, durationMS=duration, - reply=response_doc, + reply=published_reply, commandName=name, databaseName=dbname, requestId=request_id, @@ -234,8 +302,8 @@ def run_command( assert address is not None listeners.publish_command_success( duration, - response_doc, - name, + published_reply, + command_name, request_id, address, conn.server_connection_id, @@ -245,7 +313,7 @@ def run_command( database_name=dbname, ) - if client and client._encrypter and reply: + if client and client._encrypter and reply and is_command_response: decrypted = client._encrypter.decrypt(reply.raw_command_response()) docs = cast( "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index f57420918b..2e1e0d1b4f 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -26,18 +26,14 @@ Union, ) -from bson import _decode_all_selective -from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers_shared import _check_command_response from pymongo.logger import ( - _COMMAND_LOGGER, _SDAM_LOGGER, - _CommandStatusMessage, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query +from pymongo.message import _GetMore, _OpMsg, _OpReply, _Query from pymongo.response import PinnedResponse, Response +from pymongo.synchronous.command_runner import run_command from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: @@ -158,7 +154,6 @@ def run_operation( :param client: A MongoClient instance. """ assert listeners is not None - publish = listeners.enabled_for_commands start = datetime.now() use_cmd = operation.use_command(conn) @@ -166,151 +161,58 @@ def run_operation( cmd, dbn = self.operation_to_command(operation, conn, use_cmd) if more_to_come: request_id = 0 + data = b"" + max_doc_size = 0 else: message = operation.get_message(read_preference, conn, use_cmd) request_id, data, max_doc_size = self._split_message(message) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) + user_fields = _CURSOR_DOC_FIELDS if use_cmd else None - if publish: - if "$db" not in cmd: - cmd["$db"] = dbn - assert listeners is not None - listeners.publish_command_start( - cmd, - dbn, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - if more_to_come: - reply = conn.receive_message(None) - else: - conn.send_message(data, max_doc_size) - reply = conn.receive_message(request_id) - - # Unpack and check for command errors. - if use_cmd: - user_fields = _CURSOR_DOC_FIELDS - legacy_response = False - else: - user_fields = None - legacy_response = True - docs = unpack_res( - reply, - operation.cursor_id, - operation.codec_options, - legacy_response=legacy_response, - user_fields=user_fields, - ) + def _build_reply_doc( + docs: list[dict[str, Any]], reply: Optional[Union[_OpReply, _OpMsg]] + ) -> _DocumentOut: + # Must publish in find / getMore / explain command response format. if use_cmd: - first = docs[0] - operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] - _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type] - except Exception as exc: - duration = datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - listeners.publish_command_failure( - duration, - failure, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - raise - duration = datetime.now() - start - # Must publish in find / getMore / explain command response - # format. - if use_cmd: - res = docs[0] - elif operation.name == "explain": - res = docs[0] if docs else {} - else: - res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr] + return docs[0] + elif operation.name == "explain": + return docs[0] if docs else {} + res: dict[str, Any] = { + "cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, # type: ignore[union-attr] + "ok": 1, + } if operation.name == "find": res["cursor"]["firstBatch"] = docs else: res["cursor"]["nextBatch"] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=res, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - listeners.publish_command_success( - duration, - res, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - - # Decrypt response. - client = operation.client # type: ignore[assignment] - if client and client._encrypter: - if use_cmd: - decrypted = client._encrypter.decrypt(reply.raw_command_response()) - docs = _decode_all_selective(decrypted, operation.codec_options, user_fields) + return res + + docs, reply, duration = run_command( + conn, + cmd, + dbn, + request_id, + data, + client=client, + session=operation.session, # type: ignore[arg-type] + listeners=listeners, + address=conn.address, + start=start, + codec_options=operation.codec_options, + user_fields=user_fields, + command_name=operation.name, + pool_opts=conn.opts, + ensure_db=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + more_to_come=bool(more_to_come), + set_conn_more_to_come=False, + is_command_response=use_cmd, + unpack_res=unpack_res, + cursor_id=operation.cursor_id, + reply_doc_builder=_build_reply_doc, + ) + assert reply is not None response: Response @@ -332,7 +234,7 @@ def run_operation( duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs, + docs=docs, # type: ignore[arg-type] more_to_come=more_to_come, ) else: @@ -342,7 +244,7 @@ def run_operation( duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs, + docs=docs, # type: ignore[arg-type] ) return response From 3814c1760d76661de328caba2ecb8afdd75008d2 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 4 Jun 2026 19:05:35 -0400 Subject: [PATCH 3/7] PYTHON-5676 Route collection bulk writes through run_command Add process_response and decrypt_reply flags plus the conn.unack_write transport to run_command, then route bulk.write_command (acknowledged) and bulk.unack_write through it. The bulk paths pass process_response=False (they run _process_response at the call site, preserving check -> APM-succeed -> process ordering) and decrypt_reply=False (their commands are encrypted up front). The unack path publishes a copy of the command carrying the docs field while logging the bare command, matching the prior asymmetry. Drops the duplicated logging/APM/failure-conversion blocks (and the unreachable _convert_write_result-on-failure branch for unacknowledged writes). --- pymongo/asynchronous/bulk.py | 200 +++++++------------------ pymongo/asynchronous/command_runner.py | 27 +++- pymongo/synchronous/bulk.py | 200 +++++++------------------ pymongo/synchronous/command_runner.py | 27 +++- 4 files changed, 140 insertions(+), 314 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 4a54f9eb3f..f331eb5707 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -19,8 +19,6 @@ from __future__ import annotations import copy -import datetime -import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -37,6 +35,7 @@ from bson.raw_bson import RawBSONDocument from pymongo import _csot, common from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern +from pymongo.asynchronous.command_runner import run_command from pymongo.asynchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( _COMMANDS, @@ -57,14 +56,11 @@ OperationFailure, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _DELETE, _INSERT, _UPDATE, _BulkWriteContext, - _convert_exception, - _convert_write_result, _EncryptedBulkWriteContext, _randint, ) @@ -250,81 +246,36 @@ async def write_command( docs: list[Mapping[str, Any]], client: AsyncMongoClient[Any], ) -> dict[str, Any]: - """A proxy for SocketInfo.write_command that handles event publishing.""" + """Run a batch write command, returning the response as a dict.""" cmd[bwc.field] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._start(cmd, request_id, docs) try: - reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] - duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + result_docs, _, _ = await run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + use_conn_transport=True, + process_response=False, + decrypt_reply=False, + ) + reply = result_docs[0] + # Process the response from the server. await client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if bwc.publish: - bwc._fail(request_id, failure, duration) # Process the response from the server. if isinstance(exc, (NotPrimaryError, OperationFailure)): await client._process_response(exc.details, bwc.session) # type: ignore[arg-type] raise - return reply # type: ignore[return-value] + return reply async def unack_write( self, @@ -336,83 +287,34 @@ async def unack_write( docs: list[Mapping[str, Any]], client: AsyncMongoClient[Any], ) -> Optional[Mapping[str, Any]]: - """A proxy for AsyncConnection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - cmd = bwc._start(cmd, request_id, docs) - try: - result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] - duration = datetime.datetime.now() - bwc.start_time - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) - raise - return result # type: ignore[return-value] + """Send an unacknowledged batch write command.""" + # Historically the STARTED log omits the documents while the published + # CommandStartedEvent includes them, so log ``cmd`` but publish a copy + # carrying the ``docs`` field. + published = dict(cmd) + published[bwc.field] = docs + await run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + orig=published, + unacknowledged=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + process_response=False, + decrypt_reply=False, + ) + return None async def _execute_batch_unack( self, diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index 62529f42bc..94501133a9 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -82,6 +82,8 @@ async def run_command( unacknowledged: bool = False, speculative_hello: bool = False, ensure_db: bool = False, + process_response: bool = True, + decrypt_reply: bool = True, use_conn_transport: bool = False, max_doc_size: int = 0, more_to_come: bool = False, @@ -132,9 +134,15 @@ async def run_command( APM redaction. :param ensure_db: Add ``$db`` to the published command if missing (cursor path), after the ``STARTED`` log has been emitted. + :param process_response: Run ``client._process_response`` on success here; + the bulk paths pass False and process the reply at the call site to + keep their check -> APM-succeed -> process ordering. + :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; + the bulk paths pass False (their commands are encrypted up front). :param use_conn_transport: Send/receive via ``conn.send_message`` / - ``conn.receive_message`` (cursor path) instead of the raw - ``async_sendall`` / ``async_receive_message`` (network path). + ``conn.receive_message`` (cursor path) or ``conn.unack_write`` (bulk + unacknowledged) instead of the raw ``async_sendall`` / + ``async_receive_message`` (network path). :param max_doc_size: The largest document size, for ``conn.send_message``. :param more_to_come: Receive only, without sending (exhaust ``getMore``). :param set_conn_more_to_come: Store ``reply.more_to_come`` on ``conn`` (the @@ -192,14 +200,17 @@ async def run_command( try: if more_to_come: reply = await conn.receive_message(None) - elif use_conn_transport: - await conn.send_message(msg, max_doc_size) - reply = await conn.receive_message(request_id) elif unacknowledged: - await async_sendall(conn.conn.get_conn, msg) + if use_conn_transport: + await conn.unack_write(msg, max_doc_size) + else: + await async_sendall(conn.conn.get_conn, msg) # Unacknowledged, fake a successful command response. reply = None docs: list[dict[str, Any]] = [{"ok": 1}] + elif use_conn_transport: + await conn.send_message(msg, max_doc_size) + reply = await conn.receive_message(request_id) else: await async_sendall(conn.conn.get_conn, msg) reply = await async_receive_message(conn, request_id) @@ -223,7 +234,7 @@ async def run_command( cluster_time = response_doc.get("$clusterTime") if cluster_time: conn._cluster_time = cluster_time - if client: + if process_response and client: await client._process_response(response_doc, session) if check: helpers_shared._check_command_response( @@ -313,7 +324,7 @@ async def run_command( database_name=dbname, ) - if client and client._encrypter and reply and is_command_response: + if client and client._encrypter and reply and is_command_response and decrypt_reply: decrypted = await client._encrypter.decrypt(reply.raw_command_response()) docs = cast( "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 22d6a7a76a..c12f07e139 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -19,8 +19,6 @@ from __future__ import annotations import copy -import datetime -import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -55,19 +53,17 @@ OperationFailure, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _DELETE, _INSERT, _UPDATE, _BulkWriteContext, - _convert_exception, - _convert_write_result, _EncryptedBulkWriteContext, _randint, ) from pymongo.read_preferences import ReadPreference from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern +from pymongo.synchronous.command_runner import run_command from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern @@ -250,81 +246,36 @@ def write_command( docs: list[Mapping[str, Any]], client: MongoClient[Any], ) -> dict[str, Any]: - """A proxy for SocketInfo.write_command that handles event publishing.""" + """Run a batch write command, returning the response as a dict.""" cmd[bwc.field] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._start(cmd, request_id, docs) try: - reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] - duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + result_docs, _, _ = run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + use_conn_transport=True, + process_response=False, + decrypt_reply=False, + ) + reply = result_docs[0] + # Process the response from the server. client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if bwc.publish: - bwc._fail(request_id, failure, duration) # Process the response from the server. if isinstance(exc, (NotPrimaryError, OperationFailure)): client._process_response(exc.details, bwc.session) # type: ignore[arg-type] raise - return reply # type: ignore[return-value] + return reply def unack_write( self, @@ -336,83 +287,34 @@ def unack_write( docs: list[Mapping[str, Any]], client: MongoClient[Any], ) -> Optional[Mapping[str, Any]]: - """A proxy for Connection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - cmd = bwc._start(cmd, request_id, docs) - try: - result = bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] - duration = datetime.datetime.now() - bwc.start_time - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) - raise - return result # type: ignore[return-value] + """Send an unacknowledged batch write command.""" + # Historically the STARTED log omits the documents while the published + # CommandStartedEvent includes them, so log ``cmd`` but publish a copy + # carrying the ``docs`` field. + published = dict(cmd) + published[bwc.field] = docs + run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + orig=published, + unacknowledged=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + process_response=False, + decrypt_reply=False, + ) + return None def _execute_batch_unack( self, diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index 0396326e7c..fc3c9c5c59 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -82,6 +82,8 @@ def run_command( unacknowledged: bool = False, speculative_hello: bool = False, ensure_db: bool = False, + process_response: bool = True, + decrypt_reply: bool = True, use_conn_transport: bool = False, max_doc_size: int = 0, more_to_come: bool = False, @@ -132,9 +134,15 @@ def run_command( APM redaction. :param ensure_db: Add ``$db`` to the published command if missing (cursor path), after the ``STARTED`` log has been emitted. + :param process_response: Run ``client._process_response`` on success here; + the bulk paths pass False and process the reply at the call site to + keep their check -> APM-succeed -> process ordering. + :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; + the bulk paths pass False (their commands are encrypted up front). :param use_conn_transport: Send/receive via ``conn.send_message`` / - ``conn.receive_message`` (cursor path) instead of the raw - ``sendall`` / ``receive_message`` (network path). + ``conn.receive_message`` (cursor path) or ``conn.unack_write`` (bulk + unacknowledged) instead of the raw ``sendall`` / + ``receive_message`` (network path). :param max_doc_size: The largest document size, for ``conn.send_message``. :param more_to_come: Receive only, without sending (exhaust ``getMore``). :param set_conn_more_to_come: Store ``reply.more_to_come`` on ``conn`` (the @@ -192,14 +200,17 @@ def run_command( try: if more_to_come: reply = conn.receive_message(None) - elif use_conn_transport: - conn.send_message(msg, max_doc_size) - reply = conn.receive_message(request_id) elif unacknowledged: - sendall(conn.conn.get_conn, msg) + if use_conn_transport: + conn.unack_write(msg, max_doc_size) + else: + sendall(conn.conn.get_conn, msg) # Unacknowledged, fake a successful command response. reply = None docs: list[dict[str, Any]] = [{"ok": 1}] + elif use_conn_transport: + conn.send_message(msg, max_doc_size) + reply = conn.receive_message(request_id) else: sendall(conn.conn.get_conn, msg) reply = receive_message(conn, request_id) @@ -223,7 +234,7 @@ def run_command( cluster_time = response_doc.get("$clusterTime") if cluster_time: conn._cluster_time = cluster_time - if client: + if process_response and client: client._process_response(response_doc, session) if check: helpers_shared._check_command_response( @@ -313,7 +324,7 @@ def run_command( database_name=dbname, ) - if client and client._encrypter and reply and is_command_response: + if client and client._encrypter and reply and is_command_response and decrypt_reply: decrypted = client._encrypter.decrypt(reply.raw_command_response()) docs = cast( "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) From 7abf82641e645d61eebe202f0681686e4422c5f1 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 4 Jun 2026 19:10:01 -0400 Subject: [PATCH 4/7] PYTHON-5676 Route client-level bulk writes through run_command Route client_bulk.write_command and client_bulk.unack_write through run_command (process_response=False, decrypt_reply=False, conn.unack_write transport for the unack path). The client-level swallow semantics stay at the call site: the except wraps the raised error into reply={"error": exc} and runs the $clusterTime gossip (exc.details for OperationFailure, else {}); the unack path publishes a copy carrying ops/nsInfo while logging the bare command. With this, all command execution -- standard commands, cursor find/getMore, and both bulk write families -- flows through the single run_command path. --- pymongo/asynchronous/client_bulk.py | 194 +++++++--------------------- pymongo/synchronous/client_bulk.py | 194 +++++++--------------------- 2 files changed, 100 insertions(+), 288 deletions(-) diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 015947d7ef..4a74fc0855 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -19,8 +19,6 @@ from __future__ import annotations import copy -import datetime -import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -38,6 +36,7 @@ from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.command_runner import run_command from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.helpers import _handle_reauth @@ -63,12 +62,9 @@ WaitQueueTimeoutError, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _ClientBulkWriteContext, _convert_client_bulk_exception, - _convert_exception, - _convert_write_result, _randint, ) from pymongo.read_preferences import ReadPreference @@ -236,78 +232,32 @@ async def write_command( ns_docs: list[Mapping[str, Any]], client: AsyncMongoClient[Any], ) -> dict[str, Any]: - """A proxy for AsyncConnection.write_command that handles event publishing.""" + """Run a client-level batch write command, returning the response as a dict.""" cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._start(cmd, request_id, op_docs, ns_docs) try: - reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] - duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + result_docs, _, _ = await run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, # type: ignore[arg-type] + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + use_conn_transport=True, + process_response=False, + decrypt_reply=False, + ) + reply = result_docs[0] # Process the response from the server. await self.client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if bwc.publish: - bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} # Process the response from the server. @@ -327,81 +277,37 @@ async def unack_write( ns_docs: list[Mapping[str, Any]], client: AsyncMongoClient[Any], ) -> Optional[Mapping[str, Any]]: - """A proxy for AsyncConnection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - cmd = bwc._start(cmd, request_id, op_docs, ns_docs) + """Send an unacknowledged client-level batch write command.""" + # Historically the STARTED log omits the ops/nsInfo while the published + # CommandStartedEvent includes them, so log ``cmd`` but publish a copy + # carrying those fields. + published = dict(cmd) + published["ops"] = op_docs + published["nsInfo"] = ns_docs + reply: Mapping[str, Any] = {"ok": 1} try: - result = await bwc.conn.unack_write(msg, bwc.max_bson_size) # type: ignore[func-returns-value, misc, override] - duration = datetime.datetime.now() - bwc.start_time - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) + await run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + orig=published, + unacknowledged=True, + use_conn_transport=True, + max_doc_size=bwc.max_bson_size, + process_response=False, + decrypt_reply=False, + ) except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} return reply diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 1134594ae9..32d0fcc391 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -19,8 +19,6 @@ from __future__ import annotations import copy -import datetime -import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -38,6 +36,7 @@ from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.command_runner import run_command from pymongo.synchronous.database import Database from pymongo.synchronous.helpers import _handle_reauth @@ -63,12 +62,9 @@ WaitQueueTimeoutError, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _ClientBulkWriteContext, _convert_client_bulk_exception, - _convert_exception, - _convert_write_result, _randint, ) from pymongo.read_preferences import ReadPreference @@ -236,78 +232,32 @@ def write_command( ns_docs: list[Mapping[str, Any]], client: MongoClient[Any], ) -> dict[str, Any]: - """A proxy for Connection.write_command that handles event publishing.""" + """Run a client-level batch write command, returning the response as a dict.""" cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._start(cmd, request_id, op_docs, ns_docs) try: - reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] - duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + result_docs, _, _ = run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, # type: ignore[arg-type] + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + use_conn_transport=True, + process_response=False, + decrypt_reply=False, + ) + reply = result_docs[0] # Process the response from the server. self.client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if bwc.publish: - bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} # Process the response from the server. @@ -327,81 +277,37 @@ def unack_write( ns_docs: list[Mapping[str, Any]], client: MongoClient[Any], ) -> Optional[Mapping[str, Any]]: - """A proxy for Connection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - cmd = bwc._start(cmd, request_id, op_docs, ns_docs) + """Send an unacknowledged client-level batch write command.""" + # Historically the STARTED log omits the ops/nsInfo while the published + # CommandStartedEvent includes them, so log ``cmd`` but publish a copy + # carrying those fields. + published = dict(cmd) + published["ops"] = op_docs + published["nsInfo"] = ns_docs + reply: Mapping[str, Any] = {"ok": 1} try: - result = bwc.conn.unack_write(msg, bwc.max_bson_size) # type: ignore[func-returns-value, misc, override] - duration = datetime.datetime.now() - bwc.start_time - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) + run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + orig=published, + unacknowledged=True, + use_conn_transport=True, + max_doc_size=bwc.max_bson_size, + process_response=False, + decrypt_reply=False, + ) except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} return reply From c89702cb04ed887ab1c203ae04306962905a0122 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 4 Jun 2026 19:31:24 -0400 Subject: [PATCH 5/7] PYTHON-5676 Rename network.py to command_encoder.py After the consolidation, this module no longer does any networking -- the send/receive round trip moved into command_runner.run_command. It now only encodes a command and runs its pre-flight (read preference/concern, collation, $clusterTime, auto-encryption, CSOT, OP_MSG encoding), so 'network' was misleading and collided with the lower-level network_layer.py (raw sockets). Pure rename: git mv the async module (synchro regenerates the sync mirror) and update the two pool.py imports. No behavior change. --- pymongo/asynchronous/{network.py => command_encoder.py} | 9 ++++++++- pymongo/asynchronous/pool.py | 2 +- pymongo/synchronous/{network.py => command_encoder.py} | 9 ++++++++- pymongo/synchronous/pool.py | 2 +- 4 files changed, 18 insertions(+), 4 deletions(-) rename pymongo/asynchronous/{network.py => command_encoder.py} (94%) rename pymongo/synchronous/{network.py => command_encoder.py} (94%) diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/command_encoder.py similarity index 94% rename from pymongo/asynchronous/network.py rename to pymongo/asynchronous/command_encoder.py index ed86b4522f..336c1dcb50 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/command_encoder.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Internal network layer helper methods.""" +"""Encode a command and run it over a connection. + +This builds the wire-protocol message for a single command -- applying read +preference, read concern, collation, ``$clusterTime``, auto-encryption, CSOT, +and OP_MSG encoding -- then hands it to +:func:`pymongo.asynchronous.command_runner.run_command` for the network round +trip. The raw socket I/O lives in :mod:`pymongo.network_layer`. +""" from __future__ import annotations import datetime diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index a5d5b28990..9df45c494f 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -39,8 +39,8 @@ from bson import DEFAULT_CODEC_OPTIONS from pymongo import _csot, helpers_shared from pymongo.asynchronous.client_session import _validate_session_write_concern +from pymongo.asynchronous.command_encoder import command from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.asynchronous.network import command from pymongo.common import ( MAX_BSON_SIZE, MAX_MESSAGE_SIZE, diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/command_encoder.py similarity index 94% rename from pymongo/synchronous/network.py rename to pymongo/synchronous/command_encoder.py index 6576f1c5e6..5821e0c652 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/command_encoder.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Internal network layer helper methods.""" +"""Encode a command and run it over a connection. + +This builds the wire-protocol message for a single command -- applying read +preference, read concern, collation, ``$clusterTime``, auto-encryption, CSOT, +and OP_MSG encoding -- then hands it to +:func:`pymongo.command_runner.run_command` for the network round +trip. The raw socket I/O lives in :mod:`pymongo.network_layer`. +""" from __future__ import annotations import datetime diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 25f2d08fe7..5980341b11 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -88,8 +88,8 @@ from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker from pymongo.synchronous.client_session import _validate_session_write_concern +from pymongo.synchronous.command_encoder import command from pymongo.synchronous.helpers import _handle_reauth -from pymongo.synchronous.network import command if TYPE_CHECKING: from bson import CodecOptions From 3ee7e79c5332b3dec07b962254a127f493368152 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Fri, 5 Jun 2026 15:12:29 -0400 Subject: [PATCH 6/7] Noah feedback --- pymongo/asynchronous/bulk.py | 9 +- pymongo/asynchronous/client_bulk.py | 9 +- pymongo/asynchronous/command_encoder.py | 59 ++++--- pymongo/asynchronous/command_runner.py | 225 ++++++++++++++++++++++-- pymongo/asynchronous/pool.py | 18 -- pymongo/asynchronous/server.py | 7 +- pymongo/message.py | 102 ----------- pymongo/synchronous/bulk.py | 9 +- pymongo/synchronous/client_bulk.py | 9 +- pymongo/synchronous/command_encoder.py | 59 ++++--- pymongo/synchronous/command_runner.py | 225 ++++++++++++++++++++++-- pymongo/synchronous/pool.py | 18 -- pymongo/synchronous/server.py | 7 +- 13 files changed, 520 insertions(+), 236 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index f331eb5707..7bf05f5526 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -35,7 +35,7 @@ from bson.raw_bson import RawBSONDocument from pymongo import _csot, common from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern -from pymongo.asynchronous.command_runner import run_command +from pymongo.asynchronous.command_runner import run_command, run_unacknowledged_command from pymongo.asynchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( _COMMANDS, @@ -293,7 +293,7 @@ async def unack_write( # carrying the ``docs`` field. published = dict(cmd) published[bwc.field] = docs - await run_command( + await run_unacknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -308,11 +308,8 @@ async def unack_write( op_id=bwc.op_id, command_name=bwc.name, orig=published, - unacknowledged=True, use_conn_transport=True, max_doc_size=max_doc_size, - process_response=False, - decrypt_reply=False, ) return None @@ -386,7 +383,7 @@ async def _execute_command( run = self.current_run # AsyncConnection.command validates the session, but we use - # AsyncConnection.write_command + # run_command/run_unacknowledged_command. conn.validate_session(client, session) last_run = False diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 4a74fc0855..f742ff7c77 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -36,7 +36,7 @@ from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.command_runner import run_command +from pymongo.asynchronous.command_runner import run_command, run_unacknowledged_command from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.helpers import _handle_reauth @@ -286,7 +286,7 @@ async def unack_write( published["nsInfo"] = ns_docs reply: Mapping[str, Any] = {"ok": 1} try: - await run_command( + await run_unacknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -301,11 +301,8 @@ async def unack_write( op_id=bwc.op_id, command_name=bwc.name, orig=published, - unacknowledged=True, use_conn_transport=True, max_doc_size=bwc.max_bson_size, - process_response=False, - decrypt_reply=False, ) except Exception as exc: # Top-level error will be embedded in ClientBulkWriteException. @@ -406,7 +403,7 @@ async def _execute_command( listeners = self.client._event_listeners # AsyncConnection.command validates the session, but we use - # AsyncConnection.write_command + # run_command/run_unacknowledged_command. conn.validate_session(self.client, session) bwc = self.bulk_ctx_class( diff --git a/pymongo/asynchronous/command_encoder.py b/pymongo/asynchronous/command_encoder.py index 336c1dcb50..ff2ae12806 100644 --- a/pymongo/asynchronous/command_encoder.py +++ b/pymongo/asynchronous/command_encoder.py @@ -34,7 +34,7 @@ ) from pymongo import _csot, message -from pymongo.asynchronous.command_runner import run_command +from pymongo.asynchronous.command_runner import run_command, run_unacknowledged_command from pymongo.compression_support import _NO_COMPRESSION from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate @@ -155,24 +155,41 @@ async def command( if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - docs, _, _ = await run_command( - conn, - spec, - dbname, - request_id, - msg, - client=client, - session=session, - listeners=listeners, - address=address, - start=start, - codec_options=codec_options, - user_fields=user_fields, - orig=orig, - check=check, - allowable_errors=allowable_errors, - parse_write_concern_error=parse_write_concern_error, - unacknowledged=use_op_msg and unacknowledged, - speculative_hello=speculative_hello, - ) + if use_op_msg and unacknowledged: + docs, _, _ = await run_unacknowledged_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + speculative_hello=speculative_hello, + ) + else: + docs, _, _ = await run_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + speculative_hello=speculative_hello, + ) return docs[0] # type: ignore[return-value] diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index 94501133a9..df6ea0d646 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -12,15 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""The single code path for executing a command over a connection. - -Every database operation -- standard commands and cursor ``find``/``getMore`` -operations -- runs its network round trip through :func:`run_command`. The -function owns the entire shared skeleton: command logging, APM event -publishing, ``send``/``receive``, ``$clusterTime`` gossip, -``_process_response``, ``_check_command_response``, failure conversion, and -auto-encryption decryption. Callers supply only the parts that vary (the -encoded message and a handful of transport/output hooks). +"""The shared code path for executing a command over a connection. + +Every database operation runs its network round trip through one of three +public entry points -- :func:`run_command` (acknowledged commands and bulk +write batches), :func:`run_unacknowledged_command` (unacknowledged writes), and +:func:`run_cursor_command` (cursor ``find``/``getMore`` operations) -- each of +which wraps the private :func:`_run_command`. ``_run_command`` owns the entire +shared skeleton: command logging, APM event publishing, ``send``/``receive``, +``$clusterTime`` gossip, ``_process_response``, ``_check_command_response``, +failure conversion, and auto-encryption decryption. The three wrappers fix the +transport and response-shaping flags for their command type so call sites pass +only the parts that vary (the encoded message and a handful of hooks). """ from __future__ import annotations @@ -58,7 +61,7 @@ _IS_SYNC = False -async def run_command( +async def _run_command( conn: AsyncConnection, cmd: MutableMapping[str, Any], dbname: str, @@ -97,7 +100,13 @@ async def run_command( ) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. - This is the single code path for command execution. It publishes the + This is the shared implementation behind :func:`run_command`, + :func:`run_unacknowledged_command`, and :func:`run_cursor_command`. Those + three public entry points each fix the transport and response-shaping flags + for their command type; the bare kwargs here should not be set directly by + new call sites. + + It publishes the ``STARTED``/``SUCCEEDED``/``FAILED`` command log and APM events, performs the network round trip, gossips ``$clusterTime``, runs ``client._process_response`` and ``_check_command_response``, and decrypts @@ -331,3 +340,197 @@ async def run_command( ) return docs, reply, duration + + +async def run_command( + conn: AsyncConnection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[AsyncMongoClient[Any]], + session: Optional[AsyncClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + command_name: Optional[str] = None, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + speculative_hello: bool = False, + use_conn_transport: bool = False, + process_response: bool = True, + decrypt_reply: bool = True, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send an acknowledged command and return ``(docs, reply, duration)``. + + This is the entry point for standard commands and bulk write batches: it + sends ``msg``, receives the reply, runs ``_process_response`` and + ``_check_command_response``, decrypts the reply when auto-encryption is + enabled, and publishes the command log/APM events. + + :param use_conn_transport: Send/receive via ``conn.send_message`` / + ``conn.receive_message`` (bulk path) instead of the raw + ``async_sendall`` / ``async_receive_message`` (standard command path). + :param process_response: Run ``client._process_response`` here; the bulk + paths pass False and process the reply at the call site to keep their + check -> APM-succeed -> process ordering. + :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; the + bulk paths pass False (their commands are encrypted up front). + + See :func:`_run_command` for the remaining parameters. + """ + return await _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + op_id=op_id, + command_name=command_name, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + speculative_hello=speculative_hello, + use_conn_transport=use_conn_transport, + process_response=process_response, + decrypt_reply=decrypt_reply, + ) + + +async def run_unacknowledged_command( + conn: AsyncConnection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[AsyncMongoClient[Any]], + session: Optional[AsyncClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + command_name: Optional[str] = None, + speculative_hello: bool = False, + use_conn_transport: bool = False, + max_doc_size: int = 0, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send an unacknowledged command and fake an ``{"ok": 1}`` reply. + + The message is sent only -- no reply is received -- so the response + processing, command checking, and decryption steps are skipped. + + :param use_conn_transport: Send via ``conn.unack_write`` (bulk path) instead + of the raw ``async_sendall`` (standard command path). + :param max_doc_size: The largest document size, for ``conn.unack_write``. + + See :func:`_run_command` for the remaining parameters. + """ + return await _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + op_id=op_id, + command_name=command_name, + speculative_hello=speculative_hello, + unacknowledged=True, + use_conn_transport=use_conn_transport, + max_doc_size=max_doc_size, + process_response=False, + decrypt_reply=False, + ) + + +async def run_cursor_command( + conn: AsyncConnection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[AsyncMongoClient[Any]], + session: Optional[AsyncClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + command_name: str, + user_fields: Optional[Mapping[str, Any]] = None, + pool_opts: Optional[PoolOptions] = None, + max_doc_size: int = 0, + more_to_come: bool = False, + is_command_response: bool = True, + unpack_res: Optional[Callable[..., Any]] = None, + cursor_id: Optional[int] = None, + reply_doc_builder: Optional[ + Callable[[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]]], _DocumentOut] + ] = None, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Run a cursor ``find``/``getMore`` operation over ``conn``. + + Uses the connection transport, leaves ``conn.more_to_come`` untouched (the + cursor path manages exhaust separately), and shapes the published reply in + the find/getMore command response format. + + :param more_to_come: Receive only, without sending (exhaust ``getMore``). + :param is_command_response: True for an OP_MSG command response; False for a + legacy OP_QUERY cursor response. + :param unpack_res: A callable decoding the wire response. + :param cursor_id: The cursor id passed to ``unpack_res``. + :param reply_doc_builder: Builds the reply document published in the + ``SUCCEEDED`` event from ``(docs, reply)``. + + See :func:`_run_command` for the remaining parameters. + """ + return await _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + command_name=command_name, + pool_opts=pool_opts, + ensure_db=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + more_to_come=more_to_come, + set_conn_more_to_come=False, + is_command_response=is_command_response, + unpack_res=unpack_res, + cursor_id=cursor_id, + reply_doc_builder=reply_doc_builder, + ) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 9df45c494f..85968979d9 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -477,24 +477,6 @@ async def unack_write(self, msg: bytes, max_doc_size: int) -> None: self._raise_if_not_writable(True) await self.send_message(msg, max_doc_size) - async def write_command( - self, request_id: int, msg: bytes, codec_options: CodecOptions[Mapping[str, Any]] - ) -> dict[str, Any]: - """Send "insert" etc. command, returning response as a dict. - - Can raise ConnectionFailure or OperationFailure. - - :param request_id: an int. - :param msg: bytes, the command message. - """ - await self.send_message(msg, 0) - reply = await self.receive_message(request_id) - result = reply.command_response(codec_options) - - # Raises NotPrimaryError or OperationFailure. - helpers_shared._check_command_response(result, self.max_wire_version) - return result - async def authenticate(self, reauthenticate: bool = False) -> None: """Authenticate to the server if needed. diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index b18cf56c52..0c4fbed00f 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -26,7 +26,7 @@ Union, ) -from pymongo.asynchronous.command_runner import run_command +from pymongo.asynchronous.command_runner import run_cursor_command from pymongo.asynchronous.helpers import _handle_reauth from pymongo.logger import ( _SDAM_LOGGER, @@ -187,7 +187,7 @@ def _build_reply_doc( res["cursor"]["nextBatch"] = docs return res - docs, reply, duration = await run_command( + docs, reply, duration = await run_cursor_command( conn, cmd, dbn, @@ -202,11 +202,8 @@ def _build_reply_doc( user_fields=user_fields, command_name=operation.name, pool_opts=conn.opts, - ensure_db=True, - use_conn_transport=True, max_doc_size=max_doc_size, more_to_come=bool(more_to_come), - set_conn_more_to_come=False, is_command_response=use_cmd, unpack_res=unpack_res, cursor_id=operation.cursor_id, diff --git a/pymongo/message.py b/pymongo/message.py index b0d1ceb105..fd59be192a 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -75,7 +75,6 @@ _AgnosticClientSession, _AgnosticConnection, _AgnosticMongoClient, - _DocumentOut, ) @@ -152,42 +151,6 @@ def _convert_client_bulk_exception(exception: Exception) -> dict[str, Any]: } -def _convert_write_result( - operation: str, command: Mapping[str, Any], result: Mapping[str, Any] -) -> dict[str, Any]: - """Convert a legacy write result to write command format.""" - # Based on _merge_legacy from bulk.py - affected = result.get("n", 0) - res = {"ok": 1, "n": affected} - errmsg = result.get("errmsg", result.get("err", "")) - if errmsg: - # The write was successful on at least the primary so don't return. - if result.get("wtimeout"): - res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}} - else: - # The write failed. - error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg} - if "errInfo" in result: - error["errInfo"] = result["errInfo"] - res["writeErrors"] = [error] - return res - if operation == "insert": - # GLE result for insert is always 0 in most MongoDB versions. - res["n"] = len(command["documents"]) - elif operation == "update": - if "upserted" in result: - res["upserted"] = [{"index": 0, "_id": result["upserted"]}] - # Versions of MongoDB before 2.6 don't return the _id for an - # upsert if _id is not an ObjectId. - elif result.get("updatedExisting") is False and affected == 1: - # If _id is in both the update document *and* the query spec - # the update document _id takes precedence. - update = command["updates"][0] - _id = update["u"].get("_id", update["q"].get("_id")) - res["upserted"] = [{"index": 0, "_id": _id}] - return res - - _OPTIONS = { "tailable": 2, "oplogReplay": 8, @@ -636,34 +599,6 @@ def max_split_size(self) -> int: """The maximum size of a BSON command before batch splitting.""" return self.max_bson_size - def _succeed(self, request_id: int, reply: _DocumentOut, duration: datetime.timedelta) -> None: - """Publish a CommandSucceededEvent.""" - self.listeners.publish_command_success( - duration, - reply, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - - def _fail(self, request_id: int, failure: _DocumentOut, duration: datetime.timedelta) -> None: - """Publish a CommandFailedEvent.""" - self.listeners.publish_command_failure( - duration, - failure, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - class _BulkWriteContext(_BulkWriteContextBase): """A wrapper around AsyncConnection/Connection for use with the collection-level bulk write API.""" @@ -703,22 +638,6 @@ def batch_command( raise InvalidOperation("cannot do an empty bulk write") return request_id, msg, to_send - def _start( - self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]] - ) -> MutableMapping[str, Any]: - """Publish a CommandStartedEvent.""" - cmd[self.field] = docs - self.listeners.publish_command_start( - cmd, - self.db_name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - ) - return cmd - class _EncryptedBulkWriteContext(_BulkWriteContext): __slots__ = () @@ -965,27 +884,6 @@ def batch_command( raise InvalidOperation("cannot do an empty bulk write") return request_id, msg, to_send_ops, to_send_ns - def _start( - self, - cmd: MutableMapping[str, Any], - request_id: int, - op_docs: list[Mapping[str, Any]], - ns_docs: list[Mapping[str, Any]], - ) -> MutableMapping[str, Any]: - """Publish a CommandStartedEvent.""" - cmd["ops"] = op_docs - cmd["nsInfo"] = ns_docs - self.listeners.publish_command_start( - cmd, - self.db_name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - ) - return cmd - _OP_MSG_OVERHEAD = 1000 diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index c12f07e139..7527092377 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -63,7 +63,7 @@ ) from pymongo.read_preferences import ReadPreference from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern -from pymongo.synchronous.command_runner import run_command +from pymongo.synchronous.command_runner import run_command, run_unacknowledged_command from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern @@ -293,7 +293,7 @@ def unack_write( # carrying the ``docs`` field. published = dict(cmd) published[bwc.field] = docs - run_command( + run_unacknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -308,11 +308,8 @@ def unack_write( op_id=bwc.op_id, command_name=bwc.name, orig=published, - unacknowledged=True, use_conn_transport=True, max_doc_size=max_doc_size, - process_response=False, - decrypt_reply=False, ) return None @@ -386,7 +383,7 @@ def _execute_command( run = self.current_run # Connection.command validates the session, but we use - # Connection.write_command + # run_command/run_unacknowledged_command. conn.validate_session(client, session) last_run = False diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 32d0fcc391..8ae72d8cfd 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -36,7 +36,7 @@ from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.command_runner import run_command +from pymongo.synchronous.command_runner import run_command, run_unacknowledged_command from pymongo.synchronous.database import Database from pymongo.synchronous.helpers import _handle_reauth @@ -286,7 +286,7 @@ def unack_write( published["nsInfo"] = ns_docs reply: Mapping[str, Any] = {"ok": 1} try: - run_command( + run_unacknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -301,11 +301,8 @@ def unack_write( op_id=bwc.op_id, command_name=bwc.name, orig=published, - unacknowledged=True, use_conn_transport=True, max_doc_size=bwc.max_bson_size, - process_response=False, - decrypt_reply=False, ) except Exception as exc: # Top-level error will be embedded in ClientBulkWriteException. @@ -404,7 +401,7 @@ def _execute_command( listeners = self.client._event_listeners # Connection.command validates the session, but we use - # Connection.write_command + # run_command/run_unacknowledged_command. conn.validate_session(self.client, session) bwc = self.bulk_ctx_class( diff --git a/pymongo/synchronous/command_encoder.py b/pymongo/synchronous/command_encoder.py index 5821e0c652..1c05e922c9 100644 --- a/pymongo/synchronous/command_encoder.py +++ b/pymongo/synchronous/command_encoder.py @@ -37,7 +37,7 @@ from pymongo.compression_support import _NO_COMPRESSION from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate -from pymongo.synchronous.command_runner import run_command +from pymongo.synchronous.command_runner import run_command, run_unacknowledged_command if TYPE_CHECKING: from bson import CodecOptions @@ -155,24 +155,41 @@ def command( if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - docs, _, _ = run_command( - conn, - spec, - dbname, - request_id, - msg, - client=client, - session=session, - listeners=listeners, - address=address, - start=start, - codec_options=codec_options, - user_fields=user_fields, - orig=orig, - check=check, - allowable_errors=allowable_errors, - parse_write_concern_error=parse_write_concern_error, - unacknowledged=use_op_msg and unacknowledged, - speculative_hello=speculative_hello, - ) + if use_op_msg and unacknowledged: + docs, _, _ = run_unacknowledged_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + speculative_hello=speculative_hello, + ) + else: + docs, _, _ = run_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + speculative_hello=speculative_hello, + ) return docs[0] # type: ignore[return-value] diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index fc3c9c5c59..70ee8bab22 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -12,15 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""The single code path for executing a command over a connection. - -Every database operation -- standard commands and cursor ``find``/``getMore`` -operations -- runs its network round trip through :func:`run_command`. The -function owns the entire shared skeleton: command logging, APM event -publishing, ``send``/``receive``, ``$clusterTime`` gossip, -``_process_response``, ``_check_command_response``, failure conversion, and -auto-encryption decryption. Callers supply only the parts that vary (the -encoded message and a handful of transport/output hooks). +"""The shared code path for executing a command over a connection. + +Every database operation runs its network round trip through one of three +public entry points -- :func:`run_command` (acknowledged commands and bulk +write batches), :func:`run_unacknowledged_command` (unacknowledged writes), and +:func:`run_cursor_command` (cursor ``find``/``getMore`` operations) -- each of +which wraps the private :func:`_run_command`. ``_run_command`` owns the entire +shared skeleton: command logging, APM event publishing, ``send``/``receive``, +``$clusterTime`` gossip, ``_process_response``, ``_check_command_response``, +failure conversion, and auto-encryption decryption. The three wrappers fix the +transport and response-shaping flags for their command type so call sites pass +only the parts that vary (the encoded message and a handful of hooks). """ from __future__ import annotations @@ -58,7 +61,7 @@ _IS_SYNC = True -def run_command( +def _run_command( conn: Connection, cmd: MutableMapping[str, Any], dbname: str, @@ -97,7 +100,13 @@ def run_command( ) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. - This is the single code path for command execution. It publishes the + This is the shared implementation behind :func:`run_command`, + :func:`run_unacknowledged_command`, and :func:`run_cursor_command`. Those + three public entry points each fix the transport and response-shaping flags + for their command type; the bare kwargs here should not be set directly by + new call sites. + + It publishes the ``STARTED``/``SUCCEEDED``/``FAILED`` command log and APM events, performs the network round trip, gossips ``$clusterTime``, runs ``client._process_response`` and ``_check_command_response``, and decrypts @@ -331,3 +340,197 @@ def run_command( ) return docs, reply, duration + + +def run_command( + conn: Connection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[MongoClient[Any]], + session: Optional[ClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + command_name: Optional[str] = None, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + speculative_hello: bool = False, + use_conn_transport: bool = False, + process_response: bool = True, + decrypt_reply: bool = True, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send an acknowledged command and return ``(docs, reply, duration)``. + + This is the entry point for standard commands and bulk write batches: it + sends ``msg``, receives the reply, runs ``_process_response`` and + ``_check_command_response``, decrypts the reply when auto-encryption is + enabled, and publishes the command log/APM events. + + :param use_conn_transport: Send/receive via ``conn.send_message`` / + ``conn.receive_message`` (bulk path) instead of the raw + ``sendall`` / ``receive_message`` (standard command path). + :param process_response: Run ``client._process_response`` here; the bulk + paths pass False and process the reply at the call site to keep their + check -> APM-succeed -> process ordering. + :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; the + bulk paths pass False (their commands are encrypted up front). + + See :func:`_run_command` for the remaining parameters. + """ + return _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + op_id=op_id, + command_name=command_name, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + speculative_hello=speculative_hello, + use_conn_transport=use_conn_transport, + process_response=process_response, + decrypt_reply=decrypt_reply, + ) + + +def run_unacknowledged_command( + conn: Connection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[MongoClient[Any]], + session: Optional[ClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + command_name: Optional[str] = None, + speculative_hello: bool = False, + use_conn_transport: bool = False, + max_doc_size: int = 0, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send an unacknowledged command and fake an ``{"ok": 1}`` reply. + + The message is sent only -- no reply is received -- so the response + processing, command checking, and decryption steps are skipped. + + :param use_conn_transport: Send via ``conn.unack_write`` (bulk path) instead + of the raw ``sendall`` (standard command path). + :param max_doc_size: The largest document size, for ``conn.unack_write``. + + See :func:`_run_command` for the remaining parameters. + """ + return _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + op_id=op_id, + command_name=command_name, + speculative_hello=speculative_hello, + unacknowledged=True, + use_conn_transport=use_conn_transport, + max_doc_size=max_doc_size, + process_response=False, + decrypt_reply=False, + ) + + +def run_cursor_command( + conn: Connection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[MongoClient[Any]], + session: Optional[ClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + command_name: str, + user_fields: Optional[Mapping[str, Any]] = None, + pool_opts: Optional[PoolOptions] = None, + max_doc_size: int = 0, + more_to_come: bool = False, + is_command_response: bool = True, + unpack_res: Optional[Callable[..., Any]] = None, + cursor_id: Optional[int] = None, + reply_doc_builder: Optional[ + Callable[[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]]], _DocumentOut] + ] = None, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Run a cursor ``find``/``getMore`` operation over ``conn``. + + Uses the connection transport, leaves ``conn.more_to_come`` untouched (the + cursor path manages exhaust separately), and shapes the published reply in + the find/getMore command response format. + + :param more_to_come: Receive only, without sending (exhaust ``getMore``). + :param is_command_response: True for an OP_MSG command response; False for a + legacy OP_QUERY cursor response. + :param unpack_res: A callable decoding the wire response. + :param cursor_id: The cursor id passed to ``unpack_res``. + :param reply_doc_builder: Builds the reply document published in the + ``SUCCEEDED`` event from ``(docs, reply)``. + + See :func:`_run_command` for the remaining parameters. + """ + return _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + command_name=command_name, + pool_opts=pool_opts, + ensure_db=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + more_to_come=more_to_come, + set_conn_more_to_come=False, + is_command_response=is_command_response, + unpack_res=unpack_res, + cursor_id=cursor_id, + reply_doc_builder=reply_doc_builder, + ) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 5980341b11..8eba007764 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -477,24 +477,6 @@ def unack_write(self, msg: bytes, max_doc_size: int) -> None: self._raise_if_not_writable(True) self.send_message(msg, max_doc_size) - def write_command( - self, request_id: int, msg: bytes, codec_options: CodecOptions[Mapping[str, Any]] - ) -> dict[str, Any]: - """Send "insert" etc. command, returning response as a dict. - - Can raise ConnectionFailure or OperationFailure. - - :param request_id: an int. - :param msg: bytes, the command message. - """ - self.send_message(msg, 0) - reply = self.receive_message(request_id) - result = reply.command_response(codec_options) - - # Raises NotPrimaryError or OperationFailure. - helpers_shared._check_command_response(result, self.max_wire_version) - return result - def authenticate(self, reauthenticate: bool = False) -> None: """Authenticate to the server if needed. diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 2e1e0d1b4f..406f511f33 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -33,7 +33,7 @@ ) from pymongo.message import _GetMore, _OpMsg, _OpReply, _Query from pymongo.response import PinnedResponse, Response -from pymongo.synchronous.command_runner import run_command +from pymongo.synchronous.command_runner import run_cursor_command from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: @@ -187,7 +187,7 @@ def _build_reply_doc( res["cursor"]["nextBatch"] = docs return res - docs, reply, duration = run_command( + docs, reply, duration = run_cursor_command( conn, cmd, dbn, @@ -202,11 +202,8 @@ def _build_reply_doc( user_fields=user_fields, command_name=operation.name, pool_opts=conn.opts, - ensure_db=True, - use_conn_transport=True, max_doc_size=max_doc_size, more_to_come=bool(more_to_come), - set_conn_more_to_come=False, is_command_response=use_cmd, unpack_res=unpack_res, cursor_id=operation.cursor_id, From 2b022c61f80636e18a33ae272ee3c51c733d1950 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Fri, 5 Jun 2026 15:17:56 -0400 Subject: [PATCH 7/7] =?UTF-8?q?rename=20run=5Fcommand=20=E2=86=92=20run=5F?= =?UTF-8?q?acknowledged=5Fcommand?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pymongo/asynchronous/bulk.py | 9 ++++++--- pymongo/asynchronous/client_bulk.py | 9 ++++++--- pymongo/asynchronous/command_encoder.py | 10 +++++++--- pymongo/asynchronous/command_runner.py | 9 +++++---- pymongo/synchronous/bulk.py | 9 ++++++--- pymongo/synchronous/client_bulk.py | 9 ++++++--- pymongo/synchronous/command_encoder.py | 10 +++++++--- pymongo/synchronous/command_runner.py | 9 +++++---- 8 files changed, 48 insertions(+), 26 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 7bf05f5526..f93b9ceb42 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -35,7 +35,10 @@ from bson.raw_bson import RawBSONDocument from pymongo import _csot, common from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern -from pymongo.asynchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.asynchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) from pymongo.asynchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( _COMMANDS, @@ -249,7 +252,7 @@ async def write_command( """Run a batch write command, returning the response as a dict.""" cmd[bwc.field] = docs try: - result_docs, _, _ = await run_command( + result_docs, _, _ = await run_acknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -383,7 +386,7 @@ async def _execute_command( run = self.current_run # AsyncConnection.command validates the session, but we use - # run_command/run_unacknowledged_command. + # run_acknowledged_command/run_unacknowledged_command. conn.validate_session(client, session) last_run = False diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index f742ff7c77..a67da48ede 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -36,7 +36,10 @@ from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.asynchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.helpers import _handle_reauth @@ -236,7 +239,7 @@ async def write_command( cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs try: - result_docs, _, _ = await run_command( + result_docs, _, _ = await run_acknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -403,7 +406,7 @@ async def _execute_command( listeners = self.client._event_listeners # AsyncConnection.command validates the session, but we use - # run_command/run_unacknowledged_command. + # run_acknowledged_command/run_unacknowledged_command. conn.validate_session(self.client, session) bwc = self.bulk_ctx_class( diff --git a/pymongo/asynchronous/command_encoder.py b/pymongo/asynchronous/command_encoder.py index ff2ae12806..10df997072 100644 --- a/pymongo/asynchronous/command_encoder.py +++ b/pymongo/asynchronous/command_encoder.py @@ -17,7 +17,8 @@ This builds the wire-protocol message for a single command -- applying read preference, read concern, collation, ``$clusterTime``, auto-encryption, CSOT, and OP_MSG encoding -- then hands it to -:func:`pymongo.asynchronous.command_runner.run_command` for the network round +:func:`pymongo.asynchronous.command_runner.run_acknowledged_command` for the +network round trip. The raw socket I/O lives in :mod:`pymongo.network_layer`. """ from __future__ import annotations @@ -34,7 +35,10 @@ ) from pymongo import _csot, message -from pymongo.asynchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.asynchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) from pymongo.compression_support import _NO_COMPRESSION from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate @@ -173,7 +177,7 @@ async def command( speculative_hello=speculative_hello, ) else: - docs, _, _ = await run_command( + docs, _, _ = await run_acknowledged_command( conn, spec, dbname, diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index df6ea0d646..f4388d2672 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -15,8 +15,9 @@ """The shared code path for executing a command over a connection. Every database operation runs its network round trip through one of three -public entry points -- :func:`run_command` (acknowledged commands and bulk -write batches), :func:`run_unacknowledged_command` (unacknowledged writes), and +public entry points -- :func:`run_acknowledged_command` (acknowledged commands +and bulk write batches), :func:`run_unacknowledged_command` (unacknowledged +writes), and :func:`run_cursor_command` (cursor ``find``/``getMore`` operations) -- each of which wraps the private :func:`_run_command`. ``_run_command`` owns the entire shared skeleton: command logging, APM event publishing, ``send``/``receive``, @@ -100,7 +101,7 @@ async def _run_command( ) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. - This is the shared implementation behind :func:`run_command`, + This is the shared implementation behind :func:`run_acknowledged_command`, :func:`run_unacknowledged_command`, and :func:`run_cursor_command`. Those three public entry points each fix the transport and response-shaping flags for their command type; the bare kwargs here should not be set directly by @@ -342,7 +343,7 @@ async def _run_command( return docs, reply, duration -async def run_command( +async def run_acknowledged_command( conn: AsyncConnection, cmd: MutableMapping[str, Any], dbname: str, diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 7527092377..4984a20df9 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -63,7 +63,10 @@ ) from pymongo.read_preferences import ReadPreference from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern -from pymongo.synchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.synchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern @@ -249,7 +252,7 @@ def write_command( """Run a batch write command, returning the response as a dict.""" cmd[bwc.field] = docs try: - result_docs, _, _ = run_command( + result_docs, _, _ = run_acknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -383,7 +386,7 @@ def _execute_command( run = self.current_run # Connection.command validates the session, but we use - # run_command/run_unacknowledged_command. + # run_acknowledged_command/run_unacknowledged_command. conn.validate_session(client, session) last_run = False diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 8ae72d8cfd..22564e2c56 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -36,7 +36,10 @@ from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.synchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) from pymongo.synchronous.database import Database from pymongo.synchronous.helpers import _handle_reauth @@ -236,7 +239,7 @@ def write_command( cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs try: - result_docs, _, _ = run_command( + result_docs, _, _ = run_acknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -401,7 +404,7 @@ def _execute_command( listeners = self.client._event_listeners # Connection.command validates the session, but we use - # run_command/run_unacknowledged_command. + # run_acknowledged_command/run_unacknowledged_command. conn.validate_session(self.client, session) bwc = self.bulk_ctx_class( diff --git a/pymongo/synchronous/command_encoder.py b/pymongo/synchronous/command_encoder.py index 1c05e922c9..536dd5069e 100644 --- a/pymongo/synchronous/command_encoder.py +++ b/pymongo/synchronous/command_encoder.py @@ -17,7 +17,8 @@ This builds the wire-protocol message for a single command -- applying read preference, read concern, collation, ``$clusterTime``, auto-encryption, CSOT, and OP_MSG encoding -- then hands it to -:func:`pymongo.command_runner.run_command` for the network round +:func:`pymongo.command_runner.run_acknowledged_command` for the +network round trip. The raw socket I/O lives in :mod:`pymongo.network_layer`. """ from __future__ import annotations @@ -37,7 +38,10 @@ from pymongo.compression_support import _NO_COMPRESSION from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate -from pymongo.synchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.synchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) if TYPE_CHECKING: from bson import CodecOptions @@ -173,7 +177,7 @@ def command( speculative_hello=speculative_hello, ) else: - docs, _, _ = run_command( + docs, _, _ = run_acknowledged_command( conn, spec, dbname, diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index 70ee8bab22..34ce88eafd 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -15,8 +15,9 @@ """The shared code path for executing a command over a connection. Every database operation runs its network round trip through one of three -public entry points -- :func:`run_command` (acknowledged commands and bulk -write batches), :func:`run_unacknowledged_command` (unacknowledged writes), and +public entry points -- :func:`run_acknowledged_command` (acknowledged commands +and bulk write batches), :func:`run_unacknowledged_command` (unacknowledged +writes), and :func:`run_cursor_command` (cursor ``find``/``getMore`` operations) -- each of which wraps the private :func:`_run_command`. ``_run_command`` owns the entire shared skeleton: command logging, APM event publishing, ``send``/``receive``, @@ -100,7 +101,7 @@ def _run_command( ) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. - This is the shared implementation behind :func:`run_command`, + This is the shared implementation behind :func:`run_acknowledged_command`, :func:`run_unacknowledged_command`, and :func:`run_cursor_command`. Those three public entry points each fix the transport and response-shaping flags for their command type; the bare kwargs here should not be set directly by @@ -342,7 +343,7 @@ def _run_command( return docs, reply, duration -def run_command( +def run_acknowledged_command( conn: Connection, cmd: MutableMapping[str, Any], dbname: str,