Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions test/asynchronous/test_async_network_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Async-only unit tests for network_layer.py."""

from __future__ import annotations

import asyncio
import struct
import sys
from unittest.mock import AsyncMock, MagicMock, patch

sys.path[0:0] = [""]

from test.asynchronous import AsyncUnitTest, unittest

from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.errors import ProtocolError
from pymongo.network_layer import PyMongoProtocol, _async_socket_receive


async def _make_protocol(timeout=None):
protocol = PyMongoProtocol(timeout=timeout)
mock_transport = MagicMock()
mock_transport.is_closing.return_value = False
protocol.transport = mock_transport
return protocol


def _make_header(length, request_id, response_to, op_code):
return struct.pack("<iiii", length, request_id, response_to, op_code)


class TestPyMongoProtocol(AsyncUnitTest):
async def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
protocol = await _make_protocol()
protocol._max_message_size = max_size
protocol._header = memoryview(bytearray(header_bytes))
return protocol

async def test_normal_op_msg(self):
header = _make_header(length=32, request_id=1, response_to=99, op_code=2013)
protocol = await self._make_proto_with_header(header)
body_len, op_code, response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 16)
self.assertEqual(op_code, 2013)
self.assertEqual(response_to, 99)
self.assertFalse(expecting_compression)

async def test_op_compressed(self):
# OP_COMPRESSED=2012; process_header strips the 9-byte compression sub-header
# (op code + uncompressed size + compressor id), then the 16-byte standard header.
# length=35 → after compression sub-header: 26 → body: 10
header = _make_header(length=35, request_id=1, response_to=0, op_code=2012)
protocol = await self._make_proto_with_header(header)
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 10)
self.assertEqual(op_code, 2012)
self.assertTrue(expecting_compression)

async def test_op_compressed_length_too_small_raises(self):
header = _make_header(length=25, request_id=1, response_to=0, op_code=2012)
protocol = await self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()

async def test_non_compressed_length_too_small_raises(self):
header = _make_header(length=16, request_id=1, response_to=0, op_code=2013)
protocol = await self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()

async def test_length_exceeds_max_raises(self):
header = _make_header(
length=MAX_MESSAGE_SIZE + 1, request_id=1, response_to=0, op_code=2013
)
protocol = await self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()

async def test_op_reply_op_code(self):
header = _make_header(length=20, request_id=0, response_to=0, op_code=1)
protocol = await self._make_proto_with_header(header)
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 4)
self.assertEqual(op_code, 1)
self.assertFalse(expecting_compression)

async def test_compression_header_snappy_compressor_id(self):
protocol = await _make_protocol()
# <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
data = struct.pack("<iiB", 2013, 0, 1)
protocol._compression_header = memoryview(bytearray(data))
op_code, compressor_id = protocol.process_compression_header()
self.assertEqual(op_code, 2013)
self.assertEqual(compressor_id, 1)

async def test_compression_header_zlib_compressor_id(self):
protocol = await _make_protocol()
data = struct.pack("<iiB", 2013, 0, 2)
protocol._compression_header = memoryview(bytearray(data))
_, compressor_id = protocol.process_compression_header()
self.assertEqual(compressor_id, 2)

async def test_message_complete_resolves_pending_future(self):
protocol = await _make_protocol()
protocol._expecting_header = False
protocol._expecting_compression = False
protocol._message_size = 10
protocol._message = memoryview(bytearray(10))
protocol._message_index = 0
protocol._op_code = 2013
protocol._compressor_id = None
protocol._response_to = 42

future = asyncio.get_running_loop().create_future()
protocol._pending_messages.append(future)

protocol.buffer_updated(10)
self.assertTrue(future.done())
op_code, compressor_id, response_to, _ = future.result()
self.assertEqual(op_code, 2013)
self.assertIsNone(compressor_id)
self.assertEqual(response_to, 42)

async def test_close_aborts_transport(self):
protocol = await _make_protocol()
protocol.close()
self.assertTrue(protocol.transport.abort.called)

async def test_connection_lost_twice_does_not_raise(self):
protocol = await _make_protocol()
protocol.connection_lost(None)
protocol.connection_lost(None)

async def test_close_with_exception_propagates_to_pending(self):
protocol = await _make_protocol()
future = asyncio.get_running_loop().create_future()
protocol._pending_messages.append(future)
exc = OSError("connection reset")
protocol.close(exc)
with self.assertRaisesRegex(OSError, "connection reset"):
await future


class TestAsyncSocketReceive(AsyncUnitTest):
async def test_reads_data_in_multiple_chunks(self):
# Covers the loop in _async_socket_receive that accumulates short reads
# until the requested length has been received.
data = b"abcdefgh"
length = len(data)
chunk1, chunk2 = data[:4], data[4:]
mock_socket = MagicMock()
loop = asyncio.get_running_loop()
calls = 0

async def fake_recv_into(sock, buf):
nonlocal calls
if calls == 0:
buf[: len(chunk1)] = chunk1
calls += 1
return len(chunk1)
buf[: len(chunk2)] = chunk2
calls += 1
return len(chunk2)

with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
result = await _async_socket_receive(mock_socket, length, loop)
self.assertEqual(bytes(result), data)
self.assertEqual(calls, 2)

async def test_raises_on_connection_closed(self):
# Covers the explicit `raise OSError("connection closed")` branch when
# sock_recv_into returns 0.
mock_socket = MagicMock()
loop = asyncio.get_running_loop()

async def fake_recv_into(sock, buf):
return 0

with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
with self.assertRaisesRegex(OSError, "connection closed"):
await _async_socket_receive(mock_socket, 10, loop)


if __name__ == "__main__":
unittest.main()
153 changes: 153 additions & 0 deletions test/test_network_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sync-only unit tests for network_layer.py.

These cover ``receive_message`` and ``receive_data``, which only exist on the
synchronous receive path (the async path uses ``PyMongoProtocol`` instead).
The async-only tests live in ``test/asynchronous/test_async_network_layer.py``.
"""

from __future__ import annotations

import struct
import sys
from unittest.mock import MagicMock, patch

sys.path[0:0] = [""]

from test import UnitTest, unittest

from pymongo import network_layer
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.errors import ProtocolError


def _make_header(length, request_id, response_to, op_code):
return struct.pack("<iiii", length, request_id, response_to, op_code)


def _make_compression_header(op_code, uncompressed_size, compressor_id):
return struct.pack("<iiB", op_code, uncompressed_size, compressor_id)


def _make_conn():
conn = MagicMock()
conn.conn.gettimeout.return_value = None
# PyPy calls wait_for_read() before recv_into(), which checks fileno() == -1
# as an early-exit. Without this, sock.fileno() returns a MagicMock and the
# subsequent sock.pending() > 0 comparison raises TypeError on PyPy.
conn.conn.sock.fileno.return_value = -1
return conn


class TestReceiveMessage(UnitTest):
def _patch_receive_data(self, *chunks):
"""Make receive_data return the given byte strings on successive calls."""
mock = patch.object(network_layer, "receive_data", side_effect=list(chunks))
self.addCleanup(mock.stop)
return mock.start()

def test_request_id_mismatch_raises(self):
self._patch_receive_data(
_make_header(length=32, request_id=0, response_to=99, op_code=2013)
)
with self.assertRaises(ProtocolError):
network_layer.receive_message(_make_conn(), request_id=1)

def test_length_too_small_raises(self):
self._patch_receive_data(_make_header(length=16, request_id=0, response_to=0, op_code=2013))
with self.assertRaisesRegex(ProtocolError, "not longer than standard message header"):
network_layer.receive_message(_make_conn(), request_id=None)

def test_length_exceeds_max_raises(self):
self._patch_receive_data(
_make_header(length=MAX_MESSAGE_SIZE + 1, request_id=0, response_to=0, op_code=2013)
)
with self.assertRaisesRegex(ProtocolError, "larger than server max"):
network_layer.receive_message(_make_conn(), request_id=None)

def test_normal_op_msg_unpacks(self):
body = b"x" * 16
self._patch_receive_data(
_make_header(length=32, request_id=0, response_to=0, op_code=2013), body
)
unpack = MagicMock(return_value="REPLY")
with patch.object(network_layer, "_UNPACK_REPLY", {2013: unpack}):
result = network_layer.receive_message(_make_conn(), request_id=None)
unpack.assert_called_once_with(body)
self.assertEqual(result, "REPLY")

def test_op_compressed_decompresses(self):
# length=35 -> body length = 35 - 25 = 10 (header 16 + compression sub-header 9).
compressed_body = b"y" * 10
self._patch_receive_data(
_make_header(length=35, request_id=0, response_to=0, op_code=2012),
_make_compression_header(op_code=2013, uncompressed_size=0, compressor_id=1),
compressed_body,
)
unpack = MagicMock(return_value="REPLY")
with (
patch.object(network_layer, "decompress", return_value=b"decompressed") as decompress,
patch.object(network_layer, "_UNPACK_REPLY", {2013: unpack}),
):
result = network_layer.receive_message(_make_conn(), request_id=None)
decompress.assert_called_once_with(compressed_body, 1)
unpack.assert_called_once_with(b"decompressed")
self.assertEqual(result, "REPLY")

def test_unknown_opcode_raises(self):
self._patch_receive_data(
_make_header(length=20, request_id=0, response_to=0, op_code=9999), b"data"
)
with patch.object(network_layer, "_UNPACK_REPLY", {2013: MagicMock()}):
with self.assertRaises(ProtocolError):
network_layer.receive_message(_make_conn(), request_id=None)


class TestReceiveData(UnitTest):
def test_reads_data_in_multiple_chunks(self):
# Covers the loop in receive_data that accumulates short reads until the
# requested length has been received.
data = b"abcdefgh"
chunk1, chunk2 = data[:4], data[4:]
conn = _make_conn()
calls = 0

def fake_recv_into(buf):
nonlocal calls
if calls == 0:
buf[: len(chunk1)] = chunk1
calls += 1
return len(chunk1)
buf[: len(chunk2)] = chunk2
calls += 1
return len(chunk2)

conn.conn.recv_into.side_effect = fake_recv_into
result = network_layer.receive_data(conn, len(data), deadline=None)
self.assertEqual(bytes(result), data)
self.assertEqual(calls, 2)

def test_raises_on_connection_closed(self):
# Covers the explicit `raise OSError("connection closed")` branch when
# recv_into returns 0.
conn = _make_conn()
conn.conn.recv_into.return_value = 0
with self.assertRaisesRegex(OSError, "connection closed"):
network_layer.receive_data(conn, 10, deadline=None)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions tools/synchro.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def async_only_test(f: str) -> bool:
"test_async_loop_safety.py",
"test_async_contextvars_reset.py",
"test_async_loop_unblocked.py",
"test_async_network_layer.py",
]


Expand Down
Loading