Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 52 additions & 150 deletions pymongo/asynchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from __future__ import annotations

import copy
import datetime
import logging
from collections.abc import MutableMapping
from itertools import islice
from typing import (
Expand All @@ -37,6 +35,10 @@
from bson.raw_bson import RawBSONDocument
from pymongo import _csot, common
from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
from pymongo.asynchronous.command_runner import (
run_acknowledged_command,
run_unacknowledged_command,
)
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.bulk_shared import (
_COMMANDS,
Expand All @@ -57,14 +59,11 @@
OperationFailure,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import (
_DELETE,
_INSERT,
_UPDATE,
_BulkWriteContext,
_convert_exception,
_convert_write_result,
_EncryptedBulkWriteContext,
_randint,
)
Expand Down Expand Up @@ -250,81 +249,36 @@ async def write_command(
docs: list[Mapping[str, Any]],
client: AsyncMongoClient[Any],
) -> dict[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""
"""Run a batch write command, returning the response as a dict."""
cmd[bwc.field] = docs
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.STARTED,
clientId=client._topology_settings._topology_id,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._start(cmd, request_id, docs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_BulkWriteContext._start, _BulkWriteContext._succeed, _BulkWriteContext._fail, and _ClientBulkWriteContext._start are all dead code.

try:
reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc]
duration = datetime.datetime.now() - bwc.start_time
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.SUCCEEDED,
clientId=client._topology_settings._topology_id,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
result_docs, _, _ = await run_acknowledged_command(
bwc.conn, # type: ignore[arg-type]
cmd,
bwc.db_name,
request_id,
msg,
client=client,
session=bwc.session, # type: ignore[arg-type]
listeners=bwc.listeners,
address=bwc.conn.address,
start=bwc.start_time,
codec_options=bwc.codec,
op_id=bwc.op_id,
command_name=bwc.name,
use_conn_transport=True,
process_response=False,
decrypt_reply=False,
)
reply = result_docs[0]
# Process the response from the server.
await client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.FAILED,
clientId=client._topology_settings._topology_id,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)

if bwc.publish:
bwc._fail(request_id, failure, duration)
# Process the response from the server.
if isinstance(exc, (NotPrimaryError, OperationFailure)):
await client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
raise
return reply # type: ignore[return-value]
return reply

async def unack_write(
self,
Expand All @@ -336,83 +290,31 @@ async def unack_write(
docs: list[Mapping[str, Any]],
client: AsyncMongoClient[Any],
) -> Optional[Mapping[str, Any]]:
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.STARTED,
clientId=client._topology_settings._topology_id,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
cmd = bwc._start(cmd, request_id, docs)
try:
result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override]
duration = datetime.datetime.now() - bwc.start_time
if result is not None:
reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_convert_write_result is also dead code now.

else:
# Comply with APM spec.
reply = {"ok": 1}
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.SUCCEEDED,
clientId=client._topology_settings._topology_id,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration)
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, OperationFailure):
failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type]
elif isinstance(exc, NotPrimaryError):
failure = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.FAILED,
clientId=client._topology_settings._topology_id,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if bwc.publish:
assert bwc.start_time is not None
bwc._fail(request_id, failure, duration)
raise
return result # type: ignore[return-value]
"""Send an unacknowledged batch write command."""
# Historically the STARTED log omits the documents while the published
# CommandStartedEvent includes them, so log ``cmd`` but publish a copy
# carrying the ``docs`` field.
published = dict(cmd)
published[bwc.field] = docs
await run_unacknowledged_command(
bwc.conn, # type: ignore[arg-type]
cmd,
bwc.db_name,
request_id,
msg,
client=client,
session=bwc.session, # type: ignore[arg-type]
listeners=bwc.listeners,
address=bwc.conn.address,
start=bwc.start_time,
codec_options=bwc.codec,
op_id=bwc.op_id,
command_name=bwc.name,
orig=published,
use_conn_transport=True,
max_doc_size=max_doc_size,
)
return None

async def _execute_batch_unack(
self,
Expand Down Expand Up @@ -484,7 +386,7 @@ async def _execute_command(
run = self.current_run

# AsyncConnection.command validates the session, but we use
# AsyncConnection.write_command
# run_acknowledged_command/run_unacknowledged_command.
conn.validate_session(client, session)
last_run = False

Expand Down
Loading
Loading