diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 2f04431eb4..0a91ecf67f 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -20,8 +20,6 @@ from __future__ import annotations import copy -import datetime -import logging from collections.abc import Iterator, Mapping, MutableMapping from itertools import islice from typing import ( @@ -34,9 +32,9 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common -from pymongo.asynchronous.client_session import ( - AsyncClientSession, - _validate_session_write_concern, +from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern +from pymongo.asynchronous.command_runner import ( + run_bulk_write_command, ) from pymongo.asynchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( @@ -54,18 +52,14 @@ from pymongo.errors import ( ConfigurationError, InvalidOperation, - NotPrimaryError, OperationFailure, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _DELETE, _INSERT, _UPDATE, _BulkWriteContext, - _convert_exception, - _convert_write_result, _EncryptedBulkWriteContext, _randint, ) @@ -251,83 +245,16 @@ async def write_command( docs: list[Mapping[str, Any]], client: AsyncMongoClient[Any], ) -> dict[str, Any]: - """A proxy for SocketInfo.write_command that handles event publishing.""" + """Run a batch write command, returning the response as a dict.""" cmd[bwc.field] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._start(cmd, request_id, docs) - try: - if bwc.session is not None and bwc.session._starting_transaction: - bwc.session._transaction.set_in_progress() - reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] - duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] - await client._process_response(reply, bwc.session) # type: ignore[arg-type] - except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if bwc.publish: - bwc._fail(request_id, failure, duration) - # Process the response from the server. - if isinstance(exc, (NotPrimaryError, OperationFailure)): - await client._process_response(exc.details, bwc.session) # type: ignore[arg-type] - raise - return reply # type: ignore[return-value] + result_docs, _, _ = await run_bulk_write_command( + bwc, # type: ignore[arg-type] + cmd, + request_id, + msg, + client=client, + ) + return result_docs[0] async def unack_write( self, @@ -339,83 +266,23 @@ async def unack_write( docs: list[Mapping[str, Any]], client: AsyncMongoClient[Any], ) -> Optional[Mapping[str, Any]]: - """A proxy for AsyncConnection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - cmd = bwc._start(cmd, request_id, docs) - try: - result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] - duration = datetime.datetime.now() - bwc.start_time - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) - raise - return result # type: ignore[return-value] + """Send an unacknowledged batch write command.""" + # Historically the STARTED log omits the documents while the published + # CommandStartedEvent includes them, so log ``cmd`` but publish a copy + # carrying the ``docs`` field. + published = dict(cmd) + published[bwc.field] = docs + await run_bulk_write_command( + bwc, # type: ignore[arg-type] + cmd, + request_id, + msg, + client=client, + orig=published, + max_doc_size=max_doc_size, + unacknowledged=True, + ) + return None async def _execute_batch_unack( self, @@ -487,7 +354,7 @@ async def _execute_command( run = self.current_run # AsyncConnection.command validates the session, but we use - # AsyncConnection.write_command + # run_bulk_write_command. conn.validate_session(client, session) last_run = False diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 45fbc403c0..dd04bc6296 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -20,8 +20,6 @@ from __future__ import annotations import copy -import datetime -import logging from collections.abc import Mapping, MutableMapping from itertools import islice from typing import ( @@ -40,6 +38,9 @@ ) from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.command_runner import ( + run_bulk_write_command, +) from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.helpers import _handle_reauth @@ -65,12 +66,9 @@ WaitQueueTimeoutError, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _ClientBulkWriteContext, _convert_client_bulk_exception, - _convert_exception, - _convert_write_result, _randint, ) from pymongo.read_preferences import ReadPreference @@ -238,87 +236,21 @@ async def write_command( ns_docs: list[Mapping[str, Any]], client: AsyncMongoClient[Any], ) -> dict[str, Any]: - """A proxy for AsyncConnection.write_command that handles event publishing.""" + """Run a client-level batch write command, returning the response as a dict.""" cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._start(cmd, request_id, op_docs, ns_docs) try: - if bwc.session is not None and bwc.session._starting_transaction: - bwc.session._transaction.set_in_progress() - reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] - duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] - # Process the response from the server. - await self.client._process_response(reply, bwc.session) # type: ignore[arg-type] + result_docs, _, _ = await run_bulk_write_command( + bwc, # type: ignore[arg-type] + cmd, + request_id, + msg, # type: ignore[arg-type] + client=client, + ) + reply = result_docs[0] except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if bwc.publish: - bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} - # Process the response from the server. - if isinstance(exc, OperationFailure): - await self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] - else: - await self.client._process_response({}, bwc.session) # type: ignore[arg-type] return reply # type: ignore[return-value] async def unack_write( @@ -331,81 +263,26 @@ async def unack_write( ns_docs: list[Mapping[str, Any]], client: AsyncMongoClient[Any], ) -> Optional[Mapping[str, Any]]: - """A proxy for AsyncConnection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - cmd = bwc._start(cmd, request_id, op_docs, ns_docs) + """Send an unacknowledged client-level batch write command.""" + # Historically the STARTED log omits the ops/nsInfo while the published + # CommandStartedEvent includes them, so log ``cmd`` but publish a copy + # carrying those fields. + published = dict(cmd) + published["ops"] = op_docs + published["nsInfo"] = ns_docs try: - result = await bwc.conn.unack_write(msg, bwc.max_bson_size) # type: ignore[func-returns-value, misc, override] - duration = datetime.datetime.now() - bwc.start_time - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) + result_docs, _, _ = await run_bulk_write_command( + bwc, # type: ignore[arg-type] + cmd, + request_id, + msg, + client=client, + orig=published, + max_doc_size=bwc.max_bson_size, + unacknowledged=True, + ) + reply: Mapping[str, Any] = result_docs[0] except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} return reply @@ -504,7 +381,7 @@ async def _execute_command( listeners = self.client._event_listeners # AsyncConnection.command validates the session, but we use - # AsyncConnection.write_command + # run_bulk_write_command. conn.validate_session(self.client, session) bwc = self.bulk_ctx_class( diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py new file mode 100644 index 0000000000..b2e864a749 --- /dev/null +++ b/pymongo/asynchronous/command_runner.py @@ -0,0 +1,566 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Encoding and execution of commands over a connection. + +The public :func:`command` entry point applies read preference, read concern, +collation, ``$clusterTime``, auto-encryption, and CSOT to a command spec, +encodes it as an OP_MSG message, and then delegates to :func:`run_command`. + +Three public entry points each wrap the private :func:`_run_command`: + +- :func:`run_command` — standard network-transport commands (acknowledged or + unacknowledged). Called by :func:`command`. +- :func:`run_bulk_write_command` — collection-level and client-level bulk write + batches (connection transport; pre-encrypted so decryption is skipped). + Callers: ``bulk.py``, ``client_bulk.py``. +- :func:`run_cursor_command` — cursor ``find``/``getMore`` operations + (connection transport, exhaust-cursor handling). Caller: ``server.py``. + +:func:`_run_command` owns the entire shared skeleton: command logging, APM +event publishing, ``send``/``receive``, ``$clusterTime`` gossip, +``_process_response``, ``_check_command_response``, failure conversion, and +auto-encryption decryption. Each public wrapper hardcodes the transport and +response-shaping flags for its command type so callers only pass what varies. +""" + +from __future__ import annotations + +import datetime +import logging +from collections.abc import Mapping, MutableMapping, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Optional, + Protocol, + Union, + cast, +) + +from bson import _decode_all_selective +from pymongo import _csot, helpers_shared, message +from pymongo.compression_support import _NO_COMPRESSION +from pymongo.errors import NotPrimaryError, OperationFailure +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _convert_exception, _OpMsg +from pymongo.monitoring import _is_speculative_authenticate +from pymongo.network_layer import async_receive_message, async_sendall + +if TYPE_CHECKING: + from bson import CodecOptions + from pymongo.asynchronous.client_session import AsyncClientSession + from pymongo.asynchronous.mongo_client import AsyncMongoClient + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.monitoring import _EventListeners + from pymongo.pool_options import PoolOptions + from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode + from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType + from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class _BulkWriteContextProto(Protocol): + """Structural interface for bulk write context objects passed to :func:`run_bulk_write_command`.""" + + conn: AsyncConnection + db_name: str + session: Optional[AsyncClientSession] + listeners: Optional[_EventListeners] + start_time: datetime.datetime + codec: CodecOptions[Any] + op_id: int + name: str + + +async def _run_command( + conn: AsyncConnection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[AsyncMongoClient[Any]], + session: Optional[AsyncClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + command_name: Optional[str] = None, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + pool_opts: Optional[PoolOptions] = None, + unacknowledged: bool = False, + speculative_hello: bool = False, + ensure_db: bool = False, + process_response: bool = True, + decrypt_reply: bool = True, + use_conn_transport: bool = False, + max_doc_size: int = 0, + more_to_come: bool = False, + set_conn_more_to_come: bool = True, + unpack_res: Optional[Callable[..., Any]] = None, + cursor_id: Optional[int] = None, +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: + """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. + + Private shared implementation. Use :func:`run_command`, + :func:`run_bulk_write_command`, or :func:`run_cursor_command` instead. + + It publishes the + ``STARTED``/``SUCCEEDED``/``FAILED`` command log and APM events, performs + the network round trip, gossips ``$clusterTime``, runs + ``client._process_response`` and ``_check_command_response``, and decrypts + the reply when auto-encryption is enabled. + + :param conn: The AsyncConnection to send on. + :param cmd: The command document, used for the ``STARTED`` log/APM event. + :param dbname: The database the command runs against. + :param request_id: The request id of the encoded message (``0`` when + ``more_to_come`` and no message is sent). + :param msg: The encoded bytes to send (ignored when ``more_to_come``). + :param client: The AsyncMongoClient, for ``$clusterTime`` gossip, logging, + and decryption. ``None`` disables those steps (e.g. during handshake). + :param session: The session to update from the response. + :param listeners: The event listeners, or ``None`` to disable APM. + :param address: The (host, port) of ``conn`` for APM events. + :param start: The ``datetime`` the operation began, for duration timing. + :param codec_options: The CodecOptions used to decode the reply. + :param user_fields: Response fields decoded with the codec's TypeDecoders. + :param orig: The command document published in the ``STARTED`` APM event; + defaults to ``cmd`` (differs only when the wire command was mutated, + e.g. with a read preference or after encryption). + :param op_id: The APM operation id; defaults to ``request_id``. + :param command_name: The command name for the ``SUCCEEDED``/``FAILED`` APM + events; defaults to the first key of ``cmd``. + :param check: Raise OperationFailure on a command error. + :param allowable_errors: Errors to ignore when ``check`` is True. + :param parse_write_concern_error: Parse the ``writeConcernError`` field. + :param pool_opts: PoolOptions forwarded to ``_check_command_response`` (the + cursor path uses this in place of ``allowable_errors``). + :param unacknowledged: True for an unacknowledged write: send only and fake + an ``{"ok": 1}`` reply. + :param speculative_hello: True if the command carried speculative auth, for + APM redaction. + :param ensure_db: Add ``$db`` to the published command if missing (cursor + path), after the ``STARTED`` log has been emitted. + :param process_response: Run ``client._process_response`` on the response + document before ``_check_command_response`` and APM/log events. + :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; + the bulk paths pass False (their commands are encrypted up front). + :param use_conn_transport: Send/receive via ``conn.send_message`` / + ``conn.receive_message`` instead of the raw ``async_sendall`` / + ``async_receive_message`` (network path). + :param max_doc_size: The largest document size, for ``conn.send_message``. + :param more_to_come: Receive only, without sending (exhaust ``getMore``). + :param set_conn_more_to_come: Store ``reply.more_to_come`` on ``conn`` (the + network/streaming-monitor path); the cursor path manages exhaust + separately and must leave ``conn.more_to_come`` untouched. + :param unpack_res: A callable decoding the wire response (cursor path); when + ``None`` the reply's own ``unpack_response`` is used. + :param cursor_id: The cursor id passed to ``unpack_res``. + """ + name = next(iter(cmd)) + if command_name is None: + command_name = name + if orig is None: + orig = cmd + publish = listeners is not None and listeners.enabled_for_commands + + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + if publish: + assert listeners is not None + assert address is not None + if ensure_db and "$db" not in orig: + orig["$db"] = dbname + listeners.publish_command_start( + orig, + dbname, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + ) + + reply: Optional[_OpMsg] + try: + if more_to_come: + reply = await conn.receive_message(None) + elif unacknowledged: + if use_conn_transport: + conn._raise_if_not_writable() + await conn.send_message(msg, max_doc_size) + else: + await async_sendall(conn.conn.get_conn, msg) + # Unacknowledged, fake a successful command response. + reply = None + docs: list[dict[str, Any]] = [{"ok": 1}] + elif use_conn_transport: + if session is not None and session._starting_transaction: + session._transaction.set_in_progress() + await conn.send_message(msg, max_doc_size) + reply = await conn.receive_message(request_id) + else: + await async_sendall(conn.conn.get_conn, msg) + reply = await async_receive_message(conn, request_id) + + if reply is not None: + if set_conn_more_to_come: + conn.more_to_come = reply.more_to_come + if unpack_res is not None: + docs = unpack_res( + reply, + cursor_id, + codec_options, + user_fields=user_fields, + ) + else: + docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if process_response and client: + await client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + pool_opts=pool_opts, + ) + except Exception as exc: + duration = datetime.datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_failure( + duration, + failure, + command_name, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + database_name=dbname, + ) + raise + + duration = datetime.datetime.now() - start + published_reply: _DocumentOut + published_reply = docs[0] + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=published_reply, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + speculative_authenticate="speculativeAuthenticate" in orig, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_success( + duration, + published_reply, + command_name, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + speculative_hello=speculative_hello, + database_name=dbname, + ) + + if client and client._encrypter and reply and decrypt_reply: + decrypted = await client._encrypter.decrypt(reply.raw_command_response()) + docs = cast( + "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) + ) + + return docs, reply, duration + + +async def run_bulk_write_command( + bwc: _BulkWriteContextProto, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + *, + client: Optional[AsyncMongoClient[Any]], + orig: Optional[MutableMapping[str, Any]] = None, + max_doc_size: int = 0, + unacknowledged: bool = False, +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: + """Send a bulk write batch and return ``(docs, reply, duration)``. + + :param bwc: A bulk write context supplying the connection, session, listeners, etc. + :param max_doc_size: The largest document size; passed to ``conn.send_message``. + :param unacknowledged: When ``True``, send only and fake an ``{"ok": 1}`` reply. + """ + return await _run_command( + bwc.conn, + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + orig=orig, + max_doc_size=max_doc_size, + unacknowledged=unacknowledged, + use_conn_transport=True, + decrypt_reply=False, + set_conn_more_to_come=False, + process_response=not unacknowledged, + ) + + +async def run_cursor_command( + conn: AsyncConnection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[AsyncMongoClient[Any]], + session: Optional[AsyncClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + command_name: str, + user_fields: Optional[Mapping[str, Any]] = None, + pool_opts: Optional[PoolOptions] = None, + max_doc_size: int = 0, + more_to_come: bool = False, + unpack_res: Optional[Callable[..., Any]] = None, + cursor_id: Optional[int] = None, +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: + """Run a cursor ``find``/``getMore`` operation over ``conn``. + + :param more_to_come: Receive only, without sending (exhaust ``getMore``). + :param unpack_res: A callable decoding the wire response. + :param cursor_id: The cursor id passed to ``unpack_res``. + + See :func:`_run_command` for the remaining parameters. + """ + return await _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + command_name=command_name, + pool_opts=pool_opts, + ensure_db=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + more_to_come=more_to_come, + set_conn_more_to_come=False, + unpack_res=unpack_res, + cursor_id=cursor_id, + ) + + +async def run_command( + conn: AsyncConnection, + dbname: str, + spec: MutableMapping[str, Any], + is_mongos: bool, # noqa: ARG001 + read_preference: Optional[_ServerMode], + codec_options: CodecOptions[_DocumentType], + session: Optional[AsyncClientSession], + client: Optional[AsyncMongoClient[Any]], + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + address: Optional[_Address] = None, + listeners: Optional[_EventListeners] = None, + max_bson_size: Optional[int] = None, + read_concern: Optional[ReadConcern] = None, + parse_write_concern_error: bool = False, + collation: Optional[_CollationIn] = None, + compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, + unacknowledged: bool = False, + user_fields: Optional[Mapping[str, Any]] = None, + exhaust_allowed: bool = False, + write_concern: Optional[WriteConcern] = None, +) -> _DocumentType: + """Encode and execute a command over ``conn``, or raise socket.error. + + Applies read preference, read concern, collation, ``$clusterTime``, + auto-encryption, and CSOT to ``spec``, encodes it as an OP_MSG message, + then performs the network round trip and response processing. + + :param conn: a AsyncConnection instance + :param dbname: name of the database on which to run the command + :param spec: a command document as an ordered dict type, eg SON. + :param is_mongos: are we connected to a mongos? + :param read_preference: a read preference + :param codec_options: a CodecOptions instance + :param session: optional AsyncClientSession instance. + :param client: optional AsyncMongoClient instance for updating $clusterTime. + :param check: raise OperationFailure if there are errors + :param allowable_errors: errors to ignore if `check` is True + :param address: the (host, port) of `conn` + :param listeners: An instance of :class:`~pymongo.monitoring.EventListeners` + :param max_bson_size: The maximum encoded bson size for this server + :param read_concern: The read concern for this command. + :param parse_write_concern_error: Whether to parse the ``writeConcernError`` + field in the command response. + :param collation: The collation for this command. + :param compression_ctx: optional compression Context. + :param unacknowledged: True if this is an unacknowledged command. + :param user_fields: Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. + :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed. + """ + name = next(iter(spec)) + speculative_hello = False + + # Publish the original command document, perhaps with lsid and $clusterTime. + orig = spec + if read_concern and not (session and session.in_transaction): + if read_concern.level: + spec["readConcern"] = read_concern.document + if session: + session._update_read_concern(spec, conn) + if collation is not None: + spec["collation"] = collation + + publish = listeners is not None and listeners.enabled_for_commands + start = datetime.datetime.now() + if publish: + speculative_hello = _is_speculative_authenticate(name, spec) + + if compression_ctx and name.lower() in _NO_COMPRESSION: + compression_ctx = None + + if client and client._encrypter and not client._encrypter._bypass_auto_encryption: + spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options) + + # Support CSOT + if client: + conn.apply_timeout(client, spec) + _csot.apply_write_concern(spec, write_concern) + + flags = _OpMsg.MORE_TO_COME if unacknowledged else 0 + flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0 + request_id, msg, size, max_doc_size = message._op_msg( + flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx + ) + # If this is an unacknowledged write then make sure the encoded doc(s) + # are small enough, otherwise rely on the server to return an error. + if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: + message._raise_document_too_large(name, size, max_bson_size) + + if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: + message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) + docs, _, _ = await _run_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + speculative_hello=speculative_hello, + unacknowledged=unacknowledged, + process_response=not unacknowledged, + decrypt_reply=not unacknowledged, + ) + return docs[0] # type: ignore[return-value] diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py deleted file mode 100644 index 16bca3e10e..0000000000 --- a/pymongo/asynchronous/network.py +++ /dev/null @@ -1,286 +0,0 @@ -# Copyright 2015-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Internal network layer helper methods.""" - -from __future__ import annotations - -import datetime -import logging -from collections.abc import Mapping, MutableMapping, Sequence -from typing import ( - TYPE_CHECKING, - Any, - Optional, - Union, - cast, -) - -from bson import _decode_all_selective -from pymongo import _csot, helpers_shared, message -from pymongo.compression_support import _NO_COMPRESSION -from pymongo.errors import ( - NotPrimaryError, - OperationFailure, -) -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.message import _OpMsg -from pymongo.monitoring import _is_speculative_authenticate -from pymongo.network_layer import ( - async_receive_message, - async_sendall, -) - -if TYPE_CHECKING: - from bson import CodecOptions - from pymongo.asynchronous.client_session import AsyncClientSession - from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.pool import AsyncConnection - from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext - from pymongo.monitoring import _EventListeners - from pymongo.read_concern import ReadConcern - from pymongo.read_preferences import _ServerMode - from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType - from pymongo.write_concern import WriteConcern - -_IS_SYNC = False - - -async def command( - conn: AsyncConnection, - dbname: str, - spec: MutableMapping[str, Any], - is_mongos: bool, # noqa: ARG001 - read_preference: Optional[_ServerMode], - codec_options: CodecOptions[_DocumentType], - session: Optional[AsyncClientSession], - client: Optional[AsyncMongoClient[Any]], - check: bool = True, - allowable_errors: Optional[Sequence[Union[str, int]]] = None, - address: Optional[_Address] = None, - listeners: Optional[_EventListeners] = None, - max_bson_size: Optional[int] = None, - read_concern: Optional[ReadConcern] = None, - parse_write_concern_error: bool = False, - collation: Optional[_CollationIn] = None, - compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, - unacknowledged: bool = False, - user_fields: Optional[Mapping[str, Any]] = None, - exhaust_allowed: bool = False, - write_concern: Optional[WriteConcern] = None, -) -> _DocumentType: - """Execute a command over the socket, or raise socket.error. - - :param conn: a AsyncConnection instance - :param dbname: name of the database on which to run the command - :param spec: a command document as an ordered dict type, eg SON. - :param is_mongos: are we connected to a mongos? - :param read_preference: a read preference - :param codec_options: a CodecOptions instance - :param session: optional AsyncClientSession instance. - :param client: optional AsyncMongoClient instance for updating $clusterTime. - :param check: raise OperationFailure if there are errors - :param allowable_errors: errors to ignore if `check` is True - :param address: the (host, port) of `conn` - :param listeners: An instance of :class:`~pymongo.monitoring.EventListeners` - :param max_bson_size: The maximum encoded bson size for this server - :param read_concern: The read concern for this command. - :param parse_write_concern_error: Whether to parse the ``writeConcernError`` - field in the command response. - :param collation: The collation for this command. - :param compression_ctx: optional compression Context. - :param unacknowledged: True if this is an unacknowledged command. - :param user_fields: Response fields that should be decoded - using the TypeDecoders from codec_options, passed to - bson._decode_all_selective. - :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed. - """ - name = next(iter(spec)) - speculative_hello = False - - # Publish the original command document, perhaps with lsid and $clusterTime. - orig = spec - if read_concern and not (session and session.in_transaction): - if read_concern.level: - spec["readConcern"] = read_concern.document - if session: - session._update_read_concern(spec, conn) - if collation is not None: - spec["collation"] = collation - - publish = listeners is not None and listeners.enabled_for_commands - start = datetime.datetime.now() - if publish: - speculative_hello = _is_speculative_authenticate(name, spec) - - if compression_ctx and name.lower() in _NO_COMPRESSION: - compression_ctx = None - - if client and client._encrypter and not client._encrypter._bypass_auto_encryption: - spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options) - - # Support CSOT - if client: - conn.apply_timeout(client, spec) - _csot.apply_write_concern(spec, write_concern) - - flags = _OpMsg.MORE_TO_COME if unacknowledged else 0 - flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0 - request_id, msg, size, max_doc_size = message._op_msg( - flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx - ) - # If this is an unacknowledged write then make sure the encoded doc(s) - # are small enough, otherwise rely on the server to return an error. - if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: - message._raise_document_too_large(name, size, max_bson_size) - - if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: - message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=spec, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - await async_sendall(conn.conn.get_conn, msg) - if unacknowledged: - # Unacknowledged, fake a successful command response. - reply = None - response_doc: _DocumentOut = {"ok": 1} - else: - reply = await async_receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields - ) - - response_doc = unpacked_docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - await client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - ) - except Exception as exc: - duration = datetime.datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = message._convert_exception(exc) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = datetime.datetime.now() - start - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=response_doc, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - speculative_authenticate="speculativeAuthenticate" in orig, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - response_doc, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - speculative_hello=speculative_hello, - database_name=dbname, - ) - - if client and client._encrypter and reply: - decrypted = await client._encrypter.decrypt(reply.raw_command_response()) - response_doc = cast( - "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] - ) - - return response_doc # type: ignore[return-value] diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index ae7bdfea71..be3b3704a2 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -36,8 +36,8 @@ from bson import DEFAULT_CODEC_OPTIONS from pymongo import _csot, helpers_shared from pymongo.asynchronous.client_session import _validate_session_write_concern +from pymongo.asynchronous.command_runner import run_command from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.asynchronous.network import command from pymongo.common import ( MAX_BSON_SIZE, MAX_MESSAGE_SIZE, @@ -390,11 +390,12 @@ async def command( self.send_cluster_time(spec, session, client) listeners = self.listeners if publish_events else None unacknowledged = bool(write_concern and not write_concern.acknowledged) - self._raise_if_not_writable(unacknowledged) + if unacknowledged: + self._raise_if_not_writable() try: if session is not None and session._starting_transaction: session._transaction.set_in_progress() - return await command( + return await run_command( self, dbname, spec, @@ -451,43 +452,11 @@ async def receive_message(self, request_id: Optional[int]) -> _OpMsg: except BaseException as error: await self._raise_connection_failure(error) - def _raise_if_not_writable(self, unacknowledged: bool) -> None: - """Raise NotPrimaryError on unacknowledged write if this socket is not - writable. - """ - if unacknowledged and not self.is_writable: - # Write won't succeed, bail as if we'd received a not primary error. + def _raise_if_not_writable(self) -> None: + """Raise NotPrimaryError if this connection is not writable.""" + if not self.is_writable: raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) - async def unack_write(self, msg: bytes, max_doc_size: int) -> None: - """Send unack OP_MSG. - - Can raise ConnectionFailure or InvalidDocument. - - :param msg: bytes, an OP_MSG message. - :param max_doc_size: size in bytes of the largest document in `msg`. - """ - self._raise_if_not_writable(True) - await self.send_message(msg, max_doc_size) - - async def write_command( - self, request_id: int, msg: bytes, codec_options: CodecOptions[Mapping[str, Any]] - ) -> dict[str, Any]: - """Send "insert" etc. command, returning response as a dict. - - Can raise ConnectionFailure or OperationFailure. - - :param request_id: an int. - :param msg: bytes, the command message. - """ - await self.send_message(msg, 0) - reply = await self.receive_message(request_id) - result = reply.command_response(codec_options) - - # Raises NotPrimaryError or OperationFailure. - helpers_shared._check_command_response(result, self.max_wire_version) - return result - async def authenticate(self, reauthenticate: bool = False) -> None: """Authenticate to the server if needed. diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 1ca2689229..c2740caa16 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -27,18 +27,14 @@ Union, ) -from bson import _decode_all_selective +from pymongo.asynchronous.command_runner import run_cursor_command from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers_shared import _check_command_response from pymongo.logger import ( - _COMMAND_LOGGER, _SDAM_LOGGER, - _CommandStatusMessage, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query +from pymongo.message import _GetMore, _OpMsg, _Query from pymongo.response import PinnedResponse, Response if TYPE_CHECKING: @@ -159,7 +155,6 @@ async def run_operation( :param client: An AsyncMongoClient instance. """ assert listeners is not None - publish = listeners.enabled_for_commands start = datetime.now() use_cmd = operation.use_command(conn) @@ -167,144 +162,35 @@ async def run_operation( cmd, dbn = await self.operation_to_command(operation, conn, use_cmd) if more_to_come: request_id = 0 + data = b"" + max_doc_size = 0 else: message = operation.get_message(read_preference, conn, use_cmd) request_id, data, max_doc_size = self._split_message(message) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - - if publish: - if "$db" not in cmd: - cmd["$db"] = dbn - assert listeners is not None - listeners.publish_command_start( - cmd, - dbn, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - if more_to_come: - reply = await conn.receive_message(None) - else: - if operation.session is not None and operation.session._starting_transaction: - operation.session._transaction.set_in_progress() - await conn.send_message(data, max_doc_size) - reply = await conn.receive_message(request_id) - - # Unpack and check for command errors. - if use_cmd: - user_fields = _CURSOR_DOC_FIELDS - legacy_response = False - else: - user_fields = None - legacy_response = True - docs = unpack_res( - reply, - operation.cursor_id, - operation.codec_options, - legacy_response=legacy_response, - user_fields=user_fields, - ) - if use_cmd: - first = docs[0] - await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] - _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type] - except Exception as exc: - duration = datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - listeners.publish_command_failure( - duration, - failure, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - raise - duration = datetime.now() - start - # Must publish in find / getMore / explain command response - # format. - res = docs[0] - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=res, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - listeners.publish_command_success( - duration, - res, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - - # Decrypt response. - client = operation.client # type: ignore[assignment] - if client and client._encrypter: - if use_cmd: - decrypted = await client._encrypter.decrypt(reply.raw_command_response()) - docs = _decode_all_selective(decrypted, operation.codec_options, user_fields) + user_fields = _CURSOR_DOC_FIELDS if use_cmd else None + + docs, reply, duration = await run_cursor_command( + conn, + cmd, + dbn, + request_id, + data, + client=client, + session=operation.session, # type: ignore[arg-type] + listeners=listeners, + address=conn.address, + start=start, + codec_options=operation.codec_options, + user_fields=user_fields, + command_name=operation.name, + pool_opts=conn.opts, + max_doc_size=max_doc_size, + more_to_come=bool(more_to_come), + unpack_res=unpack_res, + cursor_id=operation.cursor_id, + ) + assert reply is not None response: Response @@ -326,7 +212,7 @@ async def run_operation( duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs, + docs=docs, # type: ignore[arg-type] more_to_come=more_to_come, ) else: @@ -336,7 +222,7 @@ async def run_operation( duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs, + docs=docs, # type: ignore[arg-type] ) return response diff --git a/pymongo/message.py b/pymongo/message.py index cb1d9a4184..bcd2810895 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -68,7 +68,6 @@ _AgnosticClientSession, _AgnosticConnection, _AgnosticMongoClient, - _DocumentOut, ) @@ -145,34 +144,6 @@ def _convert_client_bulk_exception(exception: Exception) -> dict[str, Any]: } -def _convert_write_result( - operation: str, command: Mapping[str, Any], result: Mapping[str, Any] -) -> dict[str, Any]: - """Convert a legacy write result to write command format.""" - # Based on _merge_legacy from bulk.py - affected = result.get("n", 0) - res = {"ok": 1, "n": affected} - errmsg = result.get("errmsg", result.get("err", "")) - if errmsg: - # The write was successful on at least the primary so don't return. - if result.get("wtimeout"): - res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}} - else: - # The write failed. - error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg} - if "errInfo" in result: - error["errInfo"] = result["errInfo"] - res["writeErrors"] = [error] - return res - if operation == "insert": - # GLE result for insert is always 0 in most MongoDB versions. - res["n"] = len(command["documents"]) - elif operation == "update": - if "upserted" in result: - res["upserted"] = [{"index": 0, "_id": result["upserted"]}] - return res - - _OPTIONS = { "tailable": 2, "oplogReplay": 8, @@ -479,7 +450,6 @@ class _BulkWriteContextBase: "name", "op_id", "op_type", - "publish", "session", "start_time", ) @@ -499,7 +469,6 @@ def __init__( self.conn = conn self.op_id = operation_id self.listeners = listeners - self.publish = listeners.enabled_for_commands self.name = cmd_name self.field = _FIELD_MAP[self.name] self.start_time = datetime.datetime.now() @@ -531,34 +500,6 @@ def max_split_size(self) -> int: """The maximum size of a BSON command before batch splitting.""" return self.max_bson_size - def _succeed(self, request_id: int, reply: _DocumentOut, duration: datetime.timedelta) -> None: - """Publish a CommandSucceededEvent.""" - self.listeners.publish_command_success( - duration, - reply, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - - def _fail(self, request_id: int, failure: _DocumentOut, duration: datetime.timedelta) -> None: - """Publish a CommandFailedEvent.""" - self.listeners.publish_command_failure( - duration, - failure, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - class _BulkWriteContext(_BulkWriteContextBase): """A wrapper around AsyncConnection/Connection for use with the collection-level bulk write API.""" @@ -598,22 +539,6 @@ def batch_command( raise InvalidOperation("cannot do an empty bulk write") return request_id, msg, to_send - def _start( - self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]] - ) -> MutableMapping[str, Any]: - """Publish a CommandStartedEvent.""" - cmd[self.field] = docs - self.listeners.publish_command_start( - cmd, - self.db_name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - ) - return cmd - class _EncryptedBulkWriteContext(_BulkWriteContext): __slots__ = () @@ -858,27 +783,6 @@ def batch_command( raise InvalidOperation("cannot do an empty bulk write") return request_id, msg, to_send_ops, to_send_ns - def _start( - self, - cmd: MutableMapping[str, Any], - request_id: int, - op_docs: list[Mapping[str, Any]], - ns_docs: list[Mapping[str, Any]], - ) -> MutableMapping[str, Any]: - """Publish a CommandStartedEvent.""" - cmd["ops"] = op_docs - cmd["nsInfo"] = ns_docs - self.listeners.publish_command_start( - cmd, - self.db_name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - ) - return cmd - _OP_MSG_OVERHEAD = 1000 diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 5dfcec27c5..8b3f8d320b 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -20,8 +20,6 @@ from __future__ import annotations import copy -import datetime -import logging from collections.abc import Iterator, Mapping, MutableMapping from itertools import islice from typing import ( @@ -49,25 +47,21 @@ from pymongo.errors import ( ConfigurationError, InvalidOperation, - NotPrimaryError, OperationFailure, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _DELETE, _INSERT, _UPDATE, _BulkWriteContext, - _convert_exception, - _convert_write_result, _EncryptedBulkWriteContext, _randint, ) from pymongo.read_preferences import ReadPreference -from pymongo.synchronous.client_session import ( - ClientSession, - _validate_session_write_concern, +from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern +from pymongo.synchronous.command_runner import ( + run_bulk_write_command, ) from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern @@ -251,83 +245,16 @@ def write_command( docs: list[Mapping[str, Any]], client: MongoClient[Any], ) -> dict[str, Any]: - """A proxy for SocketInfo.write_command that handles event publishing.""" + """Run a batch write command, returning the response as a dict.""" cmd[bwc.field] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._start(cmd, request_id, docs) - try: - if bwc.session is not None and bwc.session._starting_transaction: - bwc.session._transaction.set_in_progress() - reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] - duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] - client._process_response(reply, bwc.session) # type: ignore[arg-type] - except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if bwc.publish: - bwc._fail(request_id, failure, duration) - # Process the response from the server. - if isinstance(exc, (NotPrimaryError, OperationFailure)): - client._process_response(exc.details, bwc.session) # type: ignore[arg-type] - raise - return reply # type: ignore[return-value] + result_docs, _, _ = run_bulk_write_command( + bwc, # type: ignore[arg-type] + cmd, + request_id, + msg, + client=client, + ) + return result_docs[0] def unack_write( self, @@ -339,83 +266,23 @@ def unack_write( docs: list[Mapping[str, Any]], client: MongoClient[Any], ) -> Optional[Mapping[str, Any]]: - """A proxy for Connection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - cmd = bwc._start(cmd, request_id, docs) - try: - result = bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] - duration = datetime.datetime.now() - bwc.start_time - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) - raise - return result # type: ignore[return-value] + """Send an unacknowledged batch write command.""" + # Historically the STARTED log omits the documents while the published + # CommandStartedEvent includes them, so log ``cmd`` but publish a copy + # carrying the ``docs`` field. + published = dict(cmd) + published[bwc.field] = docs + run_bulk_write_command( + bwc, # type: ignore[arg-type] + cmd, + request_id, + msg, + client=client, + orig=published, + max_doc_size=max_doc_size, + unacknowledged=True, + ) + return None def _execute_batch_unack( self, @@ -487,7 +354,7 @@ def _execute_command( run = self.current_run # Connection.command validates the session, but we use - # Connection.write_command + # run_bulk_write_command. conn.validate_session(client, session) last_run = False diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 600b20a761..2c180ed6e5 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -20,8 +20,6 @@ from __future__ import annotations import copy -import datetime -import logging from collections.abc import Mapping, MutableMapping from itertools import islice from typing import ( @@ -40,6 +38,9 @@ ) from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.command_runner import ( + run_bulk_write_command, +) from pymongo.synchronous.database import Database from pymongo.synchronous.helpers import _handle_reauth @@ -65,12 +66,9 @@ WaitQueueTimeoutError, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _ClientBulkWriteContext, _convert_client_bulk_exception, - _convert_exception, - _convert_write_result, _randint, ) from pymongo.read_preferences import ReadPreference @@ -238,87 +236,21 @@ def write_command( ns_docs: list[Mapping[str, Any]], client: MongoClient[Any], ) -> dict[str, Any]: - """A proxy for Connection.write_command that handles event publishing.""" + """Run a client-level batch write command, returning the response as a dict.""" cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._start(cmd, request_id, op_docs, ns_docs) try: - if bwc.session is not None and bwc.session._starting_transaction: - bwc.session._transaction.set_in_progress() - reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] - duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] - # Process the response from the server. - self.client._process_response(reply, bwc.session) # type: ignore[arg-type] + result_docs, _, _ = run_bulk_write_command( + bwc, # type: ignore[arg-type] + cmd, + request_id, + msg, # type: ignore[arg-type] + client=client, + ) + reply = result_docs[0] except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if bwc.publish: - bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} - # Process the response from the server. - if isinstance(exc, OperationFailure): - self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] - else: - self.client._process_response({}, bwc.session) # type: ignore[arg-type] return reply # type: ignore[return-value] def unack_write( @@ -331,81 +263,26 @@ def unack_write( ns_docs: list[Mapping[str, Any]], client: MongoClient[Any], ) -> Optional[Mapping[str, Any]]: - """A proxy for Connection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - cmd = bwc._start(cmd, request_id, op_docs, ns_docs) + """Send an unacknowledged client-level batch write command.""" + # Historically the STARTED log omits the ops/nsInfo while the published + # CommandStartedEvent includes them, so log ``cmd`` but publish a copy + # carrying those fields. + published = dict(cmd) + published["ops"] = op_docs + published["nsInfo"] = ns_docs try: - result = bwc.conn.unack_write(msg, bwc.max_bson_size) # type: ignore[func-returns-value, misc, override] - duration = datetime.datetime.now() - bwc.start_time - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) + result_docs, _, _ = run_bulk_write_command( + bwc, # type: ignore[arg-type] + cmd, + request_id, + msg, + client=client, + orig=published, + max_doc_size=bwc.max_bson_size, + unacknowledged=True, + ) + reply: Mapping[str, Any] = result_docs[0] except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} return reply @@ -502,7 +379,7 @@ def _execute_command( listeners = self.client._event_listeners # Connection.command validates the session, but we use - # Connection.write_command + # run_bulk_write_command. conn.validate_session(self.client, session) bwc = self.bulk_ctx_class( diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py new file mode 100644 index 0000000000..37e512c681 --- /dev/null +++ b/pymongo/synchronous/command_runner.py @@ -0,0 +1,566 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Encoding and execution of commands over a connection. + +The public :func:`command` entry point applies read preference, read concern, +collation, ``$clusterTime``, auto-encryption, and CSOT to a command spec, +encodes it as an OP_MSG message, and then delegates to :func:`run_command`. + +Three public entry points each wrap the private :func:`_run_command`: + +- :func:`run_command` — standard network-transport commands (acknowledged or + unacknowledged). Called by :func:`command`. +- :func:`run_bulk_write_command` — collection-level and client-level bulk write + batches (connection transport; pre-encrypted so decryption is skipped). + Callers: ``bulk.py``, ``client_bulk.py``. +- :func:`run_cursor_command` — cursor ``find``/``getMore`` operations + (connection transport, exhaust-cursor handling). Caller: ``server.py``. + +:func:`_run_command` owns the entire shared skeleton: command logging, APM +event publishing, ``send``/``receive``, ``$clusterTime`` gossip, +``_process_response``, ``_check_command_response``, failure conversion, and +auto-encryption decryption. Each public wrapper hardcodes the transport and +response-shaping flags for its command type so callers only pass what varies. +""" + +from __future__ import annotations + +import datetime +import logging +from collections.abc import Mapping, MutableMapping, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Optional, + Protocol, + Union, + cast, +) + +from bson import _decode_all_selective +from pymongo import _csot, helpers_shared, message +from pymongo.compression_support import _NO_COMPRESSION +from pymongo.errors import NotPrimaryError, OperationFailure +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _convert_exception, _OpMsg +from pymongo.monitoring import _is_speculative_authenticate +from pymongo.network_layer import receive_message, sendall + +if TYPE_CHECKING: + from bson import CodecOptions + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.monitoring import _EventListeners + from pymongo.pool_options import PoolOptions + from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode + from pymongo.synchronous.client_session import ClientSession + from pymongo.synchronous.mongo_client import MongoClient + from pymongo.synchronous.pool import Connection + from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType + from pymongo.write_concern import WriteConcern + +_IS_SYNC = True + + +class _BulkWriteContextProto(Protocol): + """Structural interface for bulk write context objects passed to :func:`run_bulk_write_command`.""" + + conn: Connection + db_name: str + session: Optional[ClientSession] + listeners: Optional[_EventListeners] + start_time: datetime.datetime + codec: CodecOptions[Any] + op_id: int + name: str + + +def _run_command( + conn: Connection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[MongoClient[Any]], + session: Optional[ClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + command_name: Optional[str] = None, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + pool_opts: Optional[PoolOptions] = None, + unacknowledged: bool = False, + speculative_hello: bool = False, + ensure_db: bool = False, + process_response: bool = True, + decrypt_reply: bool = True, + use_conn_transport: bool = False, + max_doc_size: int = 0, + more_to_come: bool = False, + set_conn_more_to_come: bool = True, + unpack_res: Optional[Callable[..., Any]] = None, + cursor_id: Optional[int] = None, +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: + """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. + + Private shared implementation. Use :func:`run_command`, + :func:`run_bulk_write_command`, or :func:`run_cursor_command` instead. + + It publishes the + ``STARTED``/``SUCCEEDED``/``FAILED`` command log and APM events, performs + the network round trip, gossips ``$clusterTime``, runs + ``client._process_response`` and ``_check_command_response``, and decrypts + the reply when auto-encryption is enabled. + + :param conn: The Connection to send on. + :param cmd: The command document, used for the ``STARTED`` log/APM event. + :param dbname: The database the command runs against. + :param request_id: The request id of the encoded message (``0`` when + ``more_to_come`` and no message is sent). + :param msg: The encoded bytes to send (ignored when ``more_to_come``). + :param client: The MongoClient, for ``$clusterTime`` gossip, logging, + and decryption. ``None`` disables those steps (e.g. during handshake). + :param session: The session to update from the response. + :param listeners: The event listeners, or ``None`` to disable APM. + :param address: The (host, port) of ``conn`` for APM events. + :param start: The ``datetime`` the operation began, for duration timing. + :param codec_options: The CodecOptions used to decode the reply. + :param user_fields: Response fields decoded with the codec's TypeDecoders. + :param orig: The command document published in the ``STARTED`` APM event; + defaults to ``cmd`` (differs only when the wire command was mutated, + e.g. with a read preference or after encryption). + :param op_id: The APM operation id; defaults to ``request_id``. + :param command_name: The command name for the ``SUCCEEDED``/``FAILED`` APM + events; defaults to the first key of ``cmd``. + :param check: Raise OperationFailure on a command error. + :param allowable_errors: Errors to ignore when ``check`` is True. + :param parse_write_concern_error: Parse the ``writeConcernError`` field. + :param pool_opts: PoolOptions forwarded to ``_check_command_response`` (the + cursor path uses this in place of ``allowable_errors``). + :param unacknowledged: True for an unacknowledged write: send only and fake + an ``{"ok": 1}`` reply. + :param speculative_hello: True if the command carried speculative auth, for + APM redaction. + :param ensure_db: Add ``$db`` to the published command if missing (cursor + path), after the ``STARTED`` log has been emitted. + :param process_response: Run ``client._process_response`` on the response + document before ``_check_command_response`` and APM/log events. + :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; + the bulk paths pass False (their commands are encrypted up front). + :param use_conn_transport: Send/receive via ``conn.send_message`` / + ``conn.receive_message`` instead of the raw ``sendall`` / + ``receive_message`` (network path). + :param max_doc_size: The largest document size, for ``conn.send_message``. + :param more_to_come: Receive only, without sending (exhaust ``getMore``). + :param set_conn_more_to_come: Store ``reply.more_to_come`` on ``conn`` (the + network/streaming-monitor path); the cursor path manages exhaust + separately and must leave ``conn.more_to_come`` untouched. + :param unpack_res: A callable decoding the wire response (cursor path); when + ``None`` the reply's own ``unpack_response`` is used. + :param cursor_id: The cursor id passed to ``unpack_res``. + """ + name = next(iter(cmd)) + if command_name is None: + command_name = name + if orig is None: + orig = cmd + publish = listeners is not None and listeners.enabled_for_commands + + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + if publish: + assert listeners is not None + assert address is not None + if ensure_db and "$db" not in orig: + orig["$db"] = dbname + listeners.publish_command_start( + orig, + dbname, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + ) + + reply: Optional[_OpMsg] + try: + if more_to_come: + reply = conn.receive_message(None) + elif unacknowledged: + if use_conn_transport: + conn._raise_if_not_writable() + conn.send_message(msg, max_doc_size) + else: + sendall(conn.conn.get_conn, msg) + # Unacknowledged, fake a successful command response. + reply = None + docs: list[dict[str, Any]] = [{"ok": 1}] + elif use_conn_transport: + if session is not None and session._starting_transaction: + session._transaction.set_in_progress() + conn.send_message(msg, max_doc_size) + reply = conn.receive_message(request_id) + else: + sendall(conn.conn.get_conn, msg) + reply = receive_message(conn, request_id) + + if reply is not None: + if set_conn_more_to_come: + conn.more_to_come = reply.more_to_come + if unpack_res is not None: + docs = unpack_res( + reply, + cursor_id, + codec_options, + user_fields=user_fields, + ) + else: + docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if process_response and client: + client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + pool_opts=pool_opts, + ) + except Exception as exc: + duration = datetime.datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_failure( + duration, + failure, + command_name, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + database_name=dbname, + ) + raise + + duration = datetime.datetime.now() - start + published_reply: _DocumentOut + published_reply = docs[0] + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=published_reply, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + speculative_authenticate="speculativeAuthenticate" in orig, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_success( + duration, + published_reply, + command_name, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + speculative_hello=speculative_hello, + database_name=dbname, + ) + + if client and client._encrypter and reply and decrypt_reply: + decrypted = client._encrypter.decrypt(reply.raw_command_response()) + docs = cast( + "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) + ) + + return docs, reply, duration + + +def run_bulk_write_command( + bwc: _BulkWriteContextProto, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + *, + client: Optional[MongoClient[Any]], + orig: Optional[MutableMapping[str, Any]] = None, + max_doc_size: int = 0, + unacknowledged: bool = False, +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: + """Send a bulk write batch and return ``(docs, reply, duration)``. + + :param bwc: A bulk write context supplying the connection, session, listeners, etc. + :param max_doc_size: The largest document size; passed to ``conn.send_message``. + :param unacknowledged: When ``True``, send only and fake an ``{"ok": 1}`` reply. + """ + return _run_command( + bwc.conn, + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + orig=orig, + max_doc_size=max_doc_size, + unacknowledged=unacknowledged, + use_conn_transport=True, + decrypt_reply=False, + set_conn_more_to_come=False, + process_response=not unacknowledged, + ) + + +def run_cursor_command( + conn: Connection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[MongoClient[Any]], + session: Optional[ClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + command_name: str, + user_fields: Optional[Mapping[str, Any]] = None, + pool_opts: Optional[PoolOptions] = None, + max_doc_size: int = 0, + more_to_come: bool = False, + unpack_res: Optional[Callable[..., Any]] = None, + cursor_id: Optional[int] = None, +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: + """Run a cursor ``find``/``getMore`` operation over ``conn``. + + :param more_to_come: Receive only, without sending (exhaust ``getMore``). + :param unpack_res: A callable decoding the wire response. + :param cursor_id: The cursor id passed to ``unpack_res``. + + See :func:`_run_command` for the remaining parameters. + """ + return _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + command_name=command_name, + pool_opts=pool_opts, + ensure_db=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + more_to_come=more_to_come, + set_conn_more_to_come=False, + unpack_res=unpack_res, + cursor_id=cursor_id, + ) + + +def run_command( + conn: Connection, + dbname: str, + spec: MutableMapping[str, Any], + is_mongos: bool, # noqa: ARG001 + read_preference: Optional[_ServerMode], + codec_options: CodecOptions[_DocumentType], + session: Optional[ClientSession], + client: Optional[MongoClient[Any]], + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + address: Optional[_Address] = None, + listeners: Optional[_EventListeners] = None, + max_bson_size: Optional[int] = None, + read_concern: Optional[ReadConcern] = None, + parse_write_concern_error: bool = False, + collation: Optional[_CollationIn] = None, + compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, + unacknowledged: bool = False, + user_fields: Optional[Mapping[str, Any]] = None, + exhaust_allowed: bool = False, + write_concern: Optional[WriteConcern] = None, +) -> _DocumentType: + """Encode and execute a command over ``conn``, or raise socket.error. + + Applies read preference, read concern, collation, ``$clusterTime``, + auto-encryption, and CSOT to ``spec``, encodes it as an OP_MSG message, + then performs the network round trip and response processing. + + :param conn: a Connection instance + :param dbname: name of the database on which to run the command + :param spec: a command document as an ordered dict type, eg SON. + :param is_mongos: are we connected to a mongos? + :param read_preference: a read preference + :param codec_options: a CodecOptions instance + :param session: optional ClientSession instance. + :param client: optional MongoClient instance for updating $clusterTime. + :param check: raise OperationFailure if there are errors + :param allowable_errors: errors to ignore if `check` is True + :param address: the (host, port) of `conn` + :param listeners: An instance of :class:`~pymongo.monitoring.EventListeners` + :param max_bson_size: The maximum encoded bson size for this server + :param read_concern: The read concern for this command. + :param parse_write_concern_error: Whether to parse the ``writeConcernError`` + field in the command response. + :param collation: The collation for this command. + :param compression_ctx: optional compression Context. + :param unacknowledged: True if this is an unacknowledged command. + :param user_fields: Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. + :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed. + """ + name = next(iter(spec)) + speculative_hello = False + + # Publish the original command document, perhaps with lsid and $clusterTime. + orig = spec + if read_concern and not (session and session.in_transaction): + if read_concern.level: + spec["readConcern"] = read_concern.document + if session: + session._update_read_concern(spec, conn) + if collation is not None: + spec["collation"] = collation + + publish = listeners is not None and listeners.enabled_for_commands + start = datetime.datetime.now() + if publish: + speculative_hello = _is_speculative_authenticate(name, spec) + + if compression_ctx and name.lower() in _NO_COMPRESSION: + compression_ctx = None + + if client and client._encrypter and not client._encrypter._bypass_auto_encryption: + spec = orig = client._encrypter.encrypt(dbname, spec, codec_options) + + # Support CSOT + if client: + conn.apply_timeout(client, spec) + _csot.apply_write_concern(spec, write_concern) + + flags = _OpMsg.MORE_TO_COME if unacknowledged else 0 + flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0 + request_id, msg, size, max_doc_size = message._op_msg( + flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx + ) + # If this is an unacknowledged write then make sure the encoded doc(s) + # are small enough, otherwise rely on the server to return an error. + if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: + message._raise_document_too_large(name, size, max_bson_size) + + if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: + message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) + docs, _, _ = _run_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + speculative_hello=speculative_hello, + unacknowledged=unacknowledged, + process_response=not unacknowledged, + decrypt_reply=not unacknowledged, + ) + return docs[0] # type: ignore[return-value] diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py deleted file mode 100644 index e7bde29eb2..0000000000 --- a/pymongo/synchronous/network.py +++ /dev/null @@ -1,286 +0,0 @@ -# Copyright 2015-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Internal network layer helper methods.""" - -from __future__ import annotations - -import datetime -import logging -from collections.abc import Mapping, MutableMapping, Sequence -from typing import ( - TYPE_CHECKING, - Any, - Optional, - Union, - cast, -) - -from bson import _decode_all_selective -from pymongo import _csot, helpers_shared, message -from pymongo.compression_support import _NO_COMPRESSION -from pymongo.errors import ( - NotPrimaryError, - OperationFailure, -) -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.message import _OpMsg -from pymongo.monitoring import _is_speculative_authenticate -from pymongo.network_layer import ( - receive_message, - sendall, -) - -if TYPE_CHECKING: - from bson import CodecOptions - from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext - from pymongo.monitoring import _EventListeners - from pymongo.read_concern import ReadConcern - from pymongo.read_preferences import _ServerMode - from pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.mongo_client import MongoClient - from pymongo.synchronous.pool import Connection - from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType - from pymongo.write_concern import WriteConcern - -_IS_SYNC = True - - -def command( - conn: Connection, - dbname: str, - spec: MutableMapping[str, Any], - is_mongos: bool, # noqa: ARG001 - read_preference: Optional[_ServerMode], - codec_options: CodecOptions[_DocumentType], - session: Optional[ClientSession], - client: Optional[MongoClient[Any]], - check: bool = True, - allowable_errors: Optional[Sequence[Union[str, int]]] = None, - address: Optional[_Address] = None, - listeners: Optional[_EventListeners] = None, - max_bson_size: Optional[int] = None, - read_concern: Optional[ReadConcern] = None, - parse_write_concern_error: bool = False, - collation: Optional[_CollationIn] = None, - compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, - unacknowledged: bool = False, - user_fields: Optional[Mapping[str, Any]] = None, - exhaust_allowed: bool = False, - write_concern: Optional[WriteConcern] = None, -) -> _DocumentType: - """Execute a command over the socket, or raise socket.error. - - :param conn: a Connection instance - :param dbname: name of the database on which to run the command - :param spec: a command document as an ordered dict type, eg SON. - :param is_mongos: are we connected to a mongos? - :param read_preference: a read preference - :param codec_options: a CodecOptions instance - :param session: optional ClientSession instance. - :param client: optional MongoClient instance for updating $clusterTime. - :param check: raise OperationFailure if there are errors - :param allowable_errors: errors to ignore if `check` is True - :param address: the (host, port) of `conn` - :param listeners: An instance of :class:`~pymongo.monitoring.EventListeners` - :param max_bson_size: The maximum encoded bson size for this server - :param read_concern: The read concern for this command. - :param parse_write_concern_error: Whether to parse the ``writeConcernError`` - field in the command response. - :param collation: The collation for this command. - :param compression_ctx: optional compression Context. - :param unacknowledged: True if this is an unacknowledged command. - :param user_fields: Response fields that should be decoded - using the TypeDecoders from codec_options, passed to - bson._decode_all_selective. - :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed. - """ - name = next(iter(spec)) - speculative_hello = False - - # Publish the original command document, perhaps with lsid and $clusterTime. - orig = spec - if read_concern and not (session and session.in_transaction): - if read_concern.level: - spec["readConcern"] = read_concern.document - if session: - session._update_read_concern(spec, conn) - if collation is not None: - spec["collation"] = collation - - publish = listeners is not None and listeners.enabled_for_commands - start = datetime.datetime.now() - if publish: - speculative_hello = _is_speculative_authenticate(name, spec) - - if compression_ctx and name.lower() in _NO_COMPRESSION: - compression_ctx = None - - if client and client._encrypter and not client._encrypter._bypass_auto_encryption: - spec = orig = client._encrypter.encrypt(dbname, spec, codec_options) - - # Support CSOT - if client: - conn.apply_timeout(client, spec) - _csot.apply_write_concern(spec, write_concern) - - flags = _OpMsg.MORE_TO_COME if unacknowledged else 0 - flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0 - request_id, msg, size, max_doc_size = message._op_msg( - flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx - ) - # If this is an unacknowledged write then make sure the encoded doc(s) - # are small enough, otherwise rely on the server to return an error. - if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: - message._raise_document_too_large(name, size, max_bson_size) - - if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: - message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=spec, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - sendall(conn.conn.get_conn, msg) - if unacknowledged: - # Unacknowledged, fake a successful command response. - reply = None - response_doc: _DocumentOut = {"ok": 1} - else: - reply = receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields - ) - - response_doc = unpacked_docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - ) - except Exception as exc: - duration = datetime.datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = message._convert_exception(exc) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = datetime.datetime.now() - start - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=response_doc, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - speculative_authenticate="speculativeAuthenticate" in orig, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - response_doc, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - speculative_hello=speculative_hello, - database_name=dbname, - ) - - if client and client._encrypter and reply: - decrypted = client._encrypter.decrypt(reply.raw_command_response()) - response_doc = cast( - "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] - ) - - return response_doc # type: ignore[return-value] diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 2253c52419..7da71b5803 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -85,8 +85,8 @@ from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker from pymongo.synchronous.client_session import _validate_session_write_concern +from pymongo.synchronous.command_runner import run_command from pymongo.synchronous.helpers import _handle_reauth -from pymongo.synchronous.network import command if TYPE_CHECKING: from bson import CodecOptions @@ -390,11 +390,12 @@ def command( self.send_cluster_time(spec, session, client) listeners = self.listeners if publish_events else None unacknowledged = bool(write_concern and not write_concern.acknowledged) - self._raise_if_not_writable(unacknowledged) + if unacknowledged: + self._raise_if_not_writable() try: if session is not None and session._starting_transaction: session._transaction.set_in_progress() - return command( + return run_command( self, dbname, spec, @@ -451,43 +452,11 @@ def receive_message(self, request_id: Optional[int]) -> _OpMsg: except BaseException as error: self._raise_connection_failure(error) - def _raise_if_not_writable(self, unacknowledged: bool) -> None: - """Raise NotPrimaryError on unacknowledged write if this socket is not - writable. - """ - if unacknowledged and not self.is_writable: - # Write won't succeed, bail as if we'd received a not primary error. + def _raise_if_not_writable(self) -> None: + """Raise NotPrimaryError if this connection is not writable.""" + if not self.is_writable: raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) - def unack_write(self, msg: bytes, max_doc_size: int) -> None: - """Send unack OP_MSG. - - Can raise ConnectionFailure or InvalidDocument. - - :param msg: bytes, an OP_MSG message. - :param max_doc_size: size in bytes of the largest document in `msg`. - """ - self._raise_if_not_writable(True) - self.send_message(msg, max_doc_size) - - def write_command( - self, request_id: int, msg: bytes, codec_options: CodecOptions[Mapping[str, Any]] - ) -> dict[str, Any]: - """Send "insert" etc. command, returning response as a dict. - - Can raise ConnectionFailure or OperationFailure. - - :param request_id: an int. - :param msg: bytes, the command message. - """ - self.send_message(msg, 0) - reply = self.receive_message(request_id) - result = reply.command_response(codec_options) - - # Raises NotPrimaryError or OperationFailure. - helpers_shared._check_command_response(result, self.max_wire_version) - return result - def authenticate(self, reauthenticate: bool = False) -> None: """Authenticate to the server if needed. diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 46db33dbce..5e13a95648 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -27,18 +27,14 @@ Union, ) -from bson import _decode_all_selective -from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers_shared import _check_command_response from pymongo.logger import ( - _COMMAND_LOGGER, _SDAM_LOGGER, - _CommandStatusMessage, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query +from pymongo.message import _GetMore, _OpMsg, _Query from pymongo.response import PinnedResponse, Response +from pymongo.synchronous.command_runner import run_cursor_command from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: @@ -159,7 +155,6 @@ def run_operation( :param client: A MongoClient instance. """ assert listeners is not None - publish = listeners.enabled_for_commands start = datetime.now() use_cmd = operation.use_command(conn) @@ -167,144 +162,35 @@ def run_operation( cmd, dbn = self.operation_to_command(operation, conn, use_cmd) if more_to_come: request_id = 0 + data = b"" + max_doc_size = 0 else: message = operation.get_message(read_preference, conn, use_cmd) request_id, data, max_doc_size = self._split_message(message) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - - if publish: - if "$db" not in cmd: - cmd["$db"] = dbn - assert listeners is not None - listeners.publish_command_start( - cmd, - dbn, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - if more_to_come: - reply = conn.receive_message(None) - else: - if operation.session is not None and operation.session._starting_transaction: - operation.session._transaction.set_in_progress() - conn.send_message(data, max_doc_size) - reply = conn.receive_message(request_id) - - # Unpack and check for command errors. - if use_cmd: - user_fields = _CURSOR_DOC_FIELDS - legacy_response = False - else: - user_fields = None - legacy_response = True - docs = unpack_res( - reply, - operation.cursor_id, - operation.codec_options, - legacy_response=legacy_response, - user_fields=user_fields, - ) - if use_cmd: - first = docs[0] - operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] - _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type] - except Exception as exc: - duration = datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - listeners.publish_command_failure( - duration, - failure, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - raise - duration = datetime.now() - start - # Must publish in find / getMore / explain command response - # format. - res = docs[0] - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=res, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - listeners.publish_command_success( - duration, - res, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - - # Decrypt response. - client = operation.client # type: ignore[assignment] - if client and client._encrypter: - if use_cmd: - decrypted = client._encrypter.decrypt(reply.raw_command_response()) - docs = _decode_all_selective(decrypted, operation.codec_options, user_fields) + user_fields = _CURSOR_DOC_FIELDS if use_cmd else None + + docs, reply, duration = run_cursor_command( + conn, + cmd, + dbn, + request_id, + data, + client=client, + session=operation.session, # type: ignore[arg-type] + listeners=listeners, + address=conn.address, + start=start, + codec_options=operation.codec_options, + user_fields=user_fields, + command_name=operation.name, + pool_opts=conn.opts, + max_doc_size=max_doc_size, + more_to_come=bool(more_to_come), + unpack_res=unpack_res, + cursor_id=operation.cursor_id, + ) + assert reply is not None response: Response @@ -326,7 +212,7 @@ def run_operation( duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs, + docs=docs, # type: ignore[arg-type] more_to_come=more_to_come, ) else: @@ -336,7 +222,7 @@ def run_operation( duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs, + docs=docs, # type: ignore[arg-type] ) return response diff --git a/test/test_message.py b/test/test_message.py index 64aafb09d0..e3094d1e27 100644 --- a/test/test_message.py +++ b/test/test_message.py @@ -23,15 +23,12 @@ sys.path[0:0] = [""] -from test import unittest - from bson import CodecOptions, encode from pymongo.compression_support import ZlibContext, _have_zlib from pymongo.errors import DocumentTooLarge, OperationFailure from pymongo.message import ( _convert_client_bulk_exception, _convert_exception, - _convert_write_result, _gen_find_command, _gen_get_more_command, _maybe_add_read_preference, @@ -40,6 +37,7 @@ ) from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference, SecondaryPreferred +from test import unittest _OPTS = CodecOptions() @@ -100,54 +98,6 @@ def test_client_bulk_exception_includes_code(self): self.assertEqual(doc["code"], 11000) self.assertEqual(doc["errtype"], "OperationFailure") - # _convert_write_result - # In the update command spec, `q` is the query/filter and `u` is the update document. - - def test_insert_basic(self): - cmd = {"documents": [{"_id": 1}, {"_id": 2}]} - result = _convert_write_result("insert", cmd, {"n": 0}) - self.assertEqual(result["ok"], 1) - self.assertEqual(result["n"], 2) - - def test_update_basic(self): - cmd = {"updates": [{"q": {}, "u": {"$set": {"x": 1}}}]} - result = _convert_write_result("update", cmd, {"n": 1, "updatedExisting": True}) - self.assertEqual(result["ok"], 1) - self.assertNotIn("upserted", result) - - def test_update_with_upserted_id(self): - cmd = {"updates": [{"q": {}, "u": {"_id": 42}}]} - result = _convert_write_result("update", cmd, {"n": 1, "upserted": 42}) - self.assertIn("upserted", result) - self.assertEqual(result["upserted"][0]["_id"], 42) - - def test_delete_basic(self): - cmd = {"deletes": [{"q": {}, "limit": 1}]} - result = _convert_write_result("delete", cmd, {"n": 1}) - self.assertEqual(result["ok"], 1) - self.assertEqual(result["n"], 1) - - def test_write_error(self): - cmd = {"documents": [{"_id": 1}]} - gle = {"n": 0, "err": "duplicate key error", "code": 11000} - result = _convert_write_result("insert", cmd, gle) - self.assertIn("writeErrors", result) - self.assertEqual(result["writeErrors"][0]["code"], 11000) - - def test_write_concern_timeout(self): - cmd = {"documents": [{"_id": 1}]} - gle = {"n": 1, "errmsg": "timeout", "wtimeout": True} - result = _convert_write_result("insert", cmd, gle) - self.assertIn("writeConcernError", result) - self.assertEqual(result["writeConcernError"]["code"], 64) - - def test_write_error_with_err_info(self): - # Covers the `if "errInfo" in result:` branch, which test_write_error does not enter. - cmd = {"documents": [{"_id": 1}]} - gle = {"n": 0, "err": "err", "code": 123, "errInfo": {"detail": "x"}} - result = _convert_write_result("insert", cmd, gle) - self.assertIn("errInfo", result["writeErrors"][0]) - # _op_msg def test_op_msg_max_doc_size_zero_without_docs(self):