diff --git a/test/asynchronous/test_async_network_layer.py b/test/asynchronous/test_async_network_layer.py new file mode 100644 index 0000000000..15c4a6cd4c --- /dev/null +++ b/test/asynchronous/test_async_network_layer.py @@ -0,0 +1,198 @@ +# Copyright 2026-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Async-only unit tests for network_layer.py.""" + +from __future__ import annotations + +import asyncio +import struct +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncUnitTest, unittest + +from pymongo.common import MAX_MESSAGE_SIZE +from pymongo.errors import ProtocolError +from pymongo.network_layer import PyMongoProtocol, _async_socket_receive + + +async def _make_protocol(timeout=None): + protocol = PyMongoProtocol(timeout=timeout) + mock_transport = MagicMock() + mock_transport.is_closing.return_value = False + protocol.transport = mock_transport + return protocol + + +def _make_header(length, request_id, response_to, op_code): + return struct.pack(" 0 comparison raises TypeError on PyPy. + conn.conn.sock.fileno.return_value = -1 + return conn + + +class TestReceiveMessage(UnitTest): + def _patch_receive_data(self, *chunks): + """Make receive_data return the given byte strings on successive calls.""" + mock = patch.object(network_layer, "receive_data", side_effect=list(chunks)) + self.addCleanup(mock.stop) + return mock.start() + + def test_request_id_mismatch_raises(self): + self._patch_receive_data( + _make_header(length=32, request_id=0, response_to=99, op_code=2013) + ) + with self.assertRaises(ProtocolError): + network_layer.receive_message(_make_conn(), request_id=1) + + def test_length_too_small_raises(self): + self._patch_receive_data(_make_header(length=16, request_id=0, response_to=0, op_code=2013)) + with self.assertRaisesRegex(ProtocolError, "not longer than standard message header"): + network_layer.receive_message(_make_conn(), request_id=None) + + def test_length_exceeds_max_raises(self): + self._patch_receive_data( + _make_header(length=MAX_MESSAGE_SIZE + 1, request_id=0, response_to=0, op_code=2013) + ) + with self.assertRaisesRegex(ProtocolError, "larger than server max"): + network_layer.receive_message(_make_conn(), request_id=None) + + def test_normal_op_msg_unpacks(self): + body = b"x" * 16 + self._patch_receive_data( + _make_header(length=32, request_id=0, response_to=0, op_code=2013), body + ) + unpack = MagicMock(return_value="REPLY") + with patch.object(network_layer, "_UNPACK_REPLY", {2013: unpack}): + result = network_layer.receive_message(_make_conn(), request_id=None) + unpack.assert_called_once_with(body) + self.assertEqual(result, "REPLY") + + def test_op_compressed_decompresses(self): + # length=35 -> body length = 35 - 25 = 10 (header 16 + compression sub-header 9). + compressed_body = b"y" * 10 + self._patch_receive_data( + _make_header(length=35, request_id=0, response_to=0, op_code=2012), + _make_compression_header(op_code=2013, uncompressed_size=0, compressor_id=1), + compressed_body, + ) + unpack = MagicMock(return_value="REPLY") + with ( + patch.object(network_layer, "decompress", return_value=b"decompressed") as decompress, + patch.object(network_layer, "_UNPACK_REPLY", {2013: unpack}), + ): + result = network_layer.receive_message(_make_conn(), request_id=None) + decompress.assert_called_once_with(compressed_body, 1) + unpack.assert_called_once_with(b"decompressed") + self.assertEqual(result, "REPLY") + + def test_unknown_opcode_raises(self): + self._patch_receive_data( + _make_header(length=20, request_id=0, response_to=0, op_code=9999), b"data" + ) + with patch.object(network_layer, "_UNPACK_REPLY", {2013: MagicMock()}): + with self.assertRaises(ProtocolError): + network_layer.receive_message(_make_conn(), request_id=None) + + +class TestReceiveData(UnitTest): + def test_reads_data_in_multiple_chunks(self): + # Covers the loop in receive_data that accumulates short reads until the + # requested length has been received. + data = b"abcdefgh" + chunk1, chunk2 = data[:4], data[4:] + conn = _make_conn() + calls = 0 + + def fake_recv_into(buf): + nonlocal calls + if calls == 0: + buf[: len(chunk1)] = chunk1 + calls += 1 + return len(chunk1) + buf[: len(chunk2)] = chunk2 + calls += 1 + return len(chunk2) + + conn.conn.recv_into.side_effect = fake_recv_into + result = network_layer.receive_data(conn, len(data), deadline=None) + self.assertEqual(bytes(result), data) + self.assertEqual(calls, 2) + + def test_raises_on_connection_closed(self): + # Covers the explicit `raise OSError("connection closed")` branch when + # recv_into returns 0. + conn = _make_conn() + conn.conn.recv_into.return_value = 0 + with self.assertRaisesRegex(OSError, "connection closed"): + network_layer.receive_data(conn, 10, deadline=None) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/synchro.py b/tools/synchro.py index 13635a054a..769ffe4f20 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -190,6 +190,7 @@ def async_only_test(f: str) -> bool: "test_async_loop_safety.py", "test_async_contextvars_reset.py", "test_async_loop_unblocked.py", + "test_async_network_layer.py", ]