From c2052a9978e8aa1b5c82e75210703b9b1ab11b15 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 2 Mar 2026 14:53:07 -0800 Subject: [PATCH 1/3] PYTHON-5114 Test suite reduce killAllSessions calls --- test/asynchronous/unified_format.py | 14 +++++++++----- test/asynchronous/utils_spec_runner.py | 9 ++++++--- test/unified_format.py | 14 +++++++++----- test/utils_spec_runner.py | 9 ++++++--- 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 6ce8f852cf..a4d64f8a28 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -1464,11 +1464,6 @@ async def verify_outcome(self, spec): self.assertListEqual(sorted_expected_documents, actual_documents) async def run_scenario(self, spec, uri=None): - # Kill all sessions before and after each test to prevent an open - # transaction (from a test failure) from blocking collection/database - # operations during test set up and tear down. - await self.kill_all_sessions() - # Handle flaky tests. flaky_tests = [ ("PYTHON-5170", ".*test_discovery_and_monitoring.*"), @@ -1504,6 +1499,15 @@ async def _run_scenario(self, spec, uri=None): if skip_reason is not None: raise unittest.SkipTest(f"{skip_reason}") + # Kill all sessions after each test with transactions prevent an open + # transaction (from a test failure) from blocking collection/database + # operations during test set up and tear down. + for op in spec["operations"]: + name = op["name"] + if name == "startTransaction" or name == "withTransaction": + self.addAsyncCleanup(self.kill_all_sessions) + break + # process createEntities self._uri = uri self.entity_map = EntityMapUtil(self) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 63e7e9e150..c27b4c1c23 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -621,11 +621,14 @@ async def setup_scenario(self, scenario_def): async def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) - # Kill all sessions before and after each test to prevent an open + # Kill all sessions after each test with transactions prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. - await self.kill_all_sessions() - self.addAsyncCleanup(self.kill_all_sessions) + for op in test["operations"]: + name = op["name"] + if name == "startTransaction" or name == "withTransaction": + self.addAsyncCleanup(self.kill_all_sessions) + break await self.setup_scenario(scenario_def) database_name = self.get_scenario_db_name(scenario_def) collection_name = self.get_scenario_coll_name(scenario_def) diff --git a/test/unified_format.py b/test/unified_format.py index 9aee287256..31ac178cc7 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1451,11 +1451,6 @@ def verify_outcome(self, spec): self.assertListEqual(sorted_expected_documents, actual_documents) def run_scenario(self, spec, uri=None): - # Kill all sessions before and after each test to prevent an open - # transaction (from a test failure) from blocking collection/database - # operations during test set up and tear down. - self.kill_all_sessions() - # Handle flaky tests. flaky_tests = [ ("PYTHON-5170", ".*test_discovery_and_monitoring.*"), @@ -1491,6 +1486,15 @@ def _run_scenario(self, spec, uri=None): if skip_reason is not None: raise unittest.SkipTest(f"{skip_reason}") + # Kill all sessions after each test with transactions prevent an open + # transaction (from a test failure) from blocking collection/database + # operations during test set up and tear down. + for op in spec["operations"]: + name = op["name"] + if name == "startTransaction" or name == "withTransaction": + self.addCleanup(self.kill_all_sessions) + break + # process createEntities self._uri = uri self.entity_map = EntityMapUtil(self) diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 9bf155e8f3..72788c4a1a 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -621,11 +621,14 @@ def setup_scenario(self, scenario_def): def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) - # Kill all sessions before and after each test to prevent an open + # Kill all sessions after each test with transactions prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. - self.kill_all_sessions() - self.addCleanup(self.kill_all_sessions) + for op in test["operations"]: + name = op["name"] + if name == "startTransaction" or name == "withTransaction": + self.addCleanup(self.kill_all_sessions) + break self.setup_scenario(scenario_def) database_name = self.get_scenario_db_name(scenario_def) collection_name = self.get_scenario_coll_name(scenario_def) From 8f238421e05f1bcd943e690f340c4c7c85cb2dd2 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 5 Mar 2026 10:06:40 -0800 Subject: [PATCH 2/3] PYTHON-5114 Fix comment --- test/asynchronous/unified_format.py | 2 +- test/asynchronous/utils_spec_runner.py | 2 +- test/unified_format.py | 2 +- test/utils_spec_runner.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index a4d64f8a28..1fb93e7b86 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -1499,7 +1499,7 @@ async def _run_scenario(self, spec, uri=None): if skip_reason is not None: raise unittest.SkipTest(f"{skip_reason}") - # Kill all sessions after each test with transactions prevent an open + # Kill all sessions after each test with transactions to prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. for op in spec["operations"]: diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index c27b4c1c23..f099eee12c 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -621,7 +621,7 @@ async def setup_scenario(self, scenario_def): async def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) - # Kill all sessions after each test with transactions prevent an open + # Kill all sessions after each test with transactions to prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. for op in test["operations"]: diff --git a/test/unified_format.py b/test/unified_format.py index 31ac178cc7..5516a7adf1 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1486,7 +1486,7 @@ def _run_scenario(self, spec, uri=None): if skip_reason is not None: raise unittest.SkipTest(f"{skip_reason}") - # Kill all sessions after each test with transactions prevent an open + # Kill all sessions after each test with transactions to prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. for op in spec["operations"]: diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 72788c4a1a..34e1c95ef2 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -621,7 +621,7 @@ def setup_scenario(self, scenario_def): def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) - # Kill all sessions after each test with transactions prevent an open + # Kill all sessions after each test with transactions to prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. for op in test["operations"]: From c0ac7321858b2fecc95c3cd84d86e4bd12928b18 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 5 Mar 2026 11:18:28 -0800 Subject: [PATCH 3/3] PYTHON-5114 Remove unused SpecRunner class --- test/asynchronous/utils_spec_runner.py | 634 +------------------------ test/utils_spec_runner.py | 632 +----------------------- 2 files changed, 8 insertions(+), 1258 deletions(-) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index f099eee12c..344fd97536 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -16,43 +16,14 @@ from __future__ import annotations import asyncio -import functools import os import time -import unittest -from collections import abc -from inspect import iscoroutinefunction -from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs +from test.asynchronous import async_client_context from test.asynchronous.helpers import ConcurrentRunner -from test.utils_shared import ( - CMAPListener, - CompareType, - EventListener, - OvertCommandListener, - ScenarioDict, - ServerAndTopologyEventListener, - camel_to_snake, - camel_to_snake_args, - parse_spec_options, - prepare_spec_arguments, -) -from typing import List - -from bson import ObjectId, decode, encode, json_util -from bson.binary import Binary -from bson.int64 import Int64 -from bson.son import SON -from gridfs import GridFSBucket -from gridfs.asynchronous.grid_file import AsyncGridFSBucket -from pymongo.asynchronous import client_session -from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.cursor import AsyncCursor -from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from test.utils_shared import ScenarioDict + +from bson import json_util from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock -from pymongo.read_concern import ReadConcern -from pymongo.read_preferences import ReadPreference -from pymongo.results import BulkWriteResult, _WriteResult -from pymongo.write_concern import WriteConcern _IS_SYNC = False @@ -219,600 +190,3 @@ def create_tests(self): self._create_tests() else: asyncio.run(self._create_tests()) - - -class AsyncSpecRunner(AsyncIntegrationTest): - mongos_clients: List - knobs: client_knobs - listener: EventListener - - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - self.mongos_clients = [] - - # Speed up the tests by decreasing the heartbeat frequency. - self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - self.knobs.enable() - self.targets = {} - self.listener = None # type: ignore - self.pool_listener = None - self.server_listener = None - self.maxDiff = None - - async def asyncTearDown(self) -> None: - self.knobs.disable() - - async def set_fail_point(self, command_args): - clients = self.mongos_clients if self.mongos_clients else [self.client] - for client in clients: - await self.configure_fail_point(client, command_args) - - async def targeted_fail_point(self, session, fail_point): - """Run the targetedFailPoint test operation. - - Enable the fail point on the session's pinned mongos. - """ - clients = {c.address: c for c in self.mongos_clients} - client = clients[session._pinned_address] - await self.configure_fail_point(client, fail_point) - self.addAsyncCleanup(self.set_fail_point, {"mode": "off"}) - - def assert_session_pinned(self, session): - """Run the assertSessionPinned test operation. - - Assert that the given session is pinned. - """ - self.assertIsNotNone(session._transaction.pinned_address) - - def assert_session_unpinned(self, session): - """Run the assertSessionUnpinned test operation. - - Assert that the given session is not pinned. - """ - self.assertIsNone(session._pinned_address) - self.assertIsNone(session._transaction.pinned_address) - - async def assert_collection_exists(self, database, collection): - """Run the assertCollectionExists test operation.""" - db = self.client[database] - self.assertIn(collection, await db.list_collection_names()) - - async def assert_collection_not_exists(self, database, collection): - """Run the assertCollectionNotExists test operation.""" - db = self.client[database] - self.assertNotIn(collection, await db.list_collection_names()) - - async def assert_index_exists(self, database, collection, index): - """Run the assertIndexExists test operation.""" - coll = self.client[database][collection] - self.assertIn(index, [doc["name"] async for doc in await coll.list_indexes()]) - - async def assert_index_not_exists(self, database, collection, index): - """Run the assertIndexNotExists test operation.""" - coll = self.client[database][collection] - self.assertNotIn(index, [doc["name"] async for doc in await coll.list_indexes()]) - - async def wait(self, ms): - """Run the "wait" test operation.""" - await asyncio.sleep(ms / 1000.0) - - def assertErrorLabelsContain(self, exc, expected_labels): - labels = [l for l in expected_labels if exc.has_error_label(l)] - self.assertEqual(labels, expected_labels) - - def assertErrorLabelsOmit(self, exc, omit_labels): - for label in omit_labels: - self.assertFalse( - exc.has_error_label(label), msg=f"error labels should not contain {label}" - ) - - async def kill_all_sessions(self): - clients = self.mongos_clients if self.mongos_clients else [self.client] - for client in clients: - try: - await client.admin.command("killAllSessions", []) - except (OperationFailure, AutoReconnect): - # "operation was interrupted" by killing the command's - # own session. - # On 8.0+ killAllSessions sometimes returns a network error. - pass - - def check_command_result(self, expected_result, result): - # Only compare the keys in the expected result. - filtered_result = {} - for key in expected_result: - try: - filtered_result[key] = result[key] - except KeyError: - pass - self.assertEqual(filtered_result, expected_result) - - # TODO: factor the following function with test_crud.py. - def check_result(self, expected_result, result): - if isinstance(result, _WriteResult): - for res in expected_result: - prop = camel_to_snake(res) - # SPEC-869: Only BulkWriteResult has upserted_count. - if prop == "upserted_count" and not isinstance(result, BulkWriteResult): - if result.upserted_id is not None: - upserted_count = 1 - else: - upserted_count = 0 - self.assertEqual(upserted_count, expected_result[res], prop) - elif prop == "inserted_ids": - # BulkWriteResult does not have inserted_ids. - if isinstance(result, BulkWriteResult): - self.assertEqual(len(expected_result[res]), result.inserted_count) - else: - # InsertManyResult may be compared to [id1] from the - # crud spec or {"0": id1} from the retryable write spec. - ids = expected_result[res] - if isinstance(ids, dict): - ids = [ids[str(i)] for i in range(len(ids))] - - self.assertEqual(ids, result.inserted_ids, prop) - elif prop == "upserted_ids": - # Convert indexes from strings to integers. - ids = expected_result[res] - expected_ids = {} - for str_index in ids: - expected_ids[int(str_index)] = ids[str_index] - self.assertEqual(expected_ids, result.upserted_ids, prop) - else: - self.assertEqual(getattr(result, prop), expected_result[res], prop) - - return True - else: - - def _helper(expected_result, result): - if isinstance(expected_result, abc.Mapping): - for i in expected_result.keys(): - self.assertEqual(expected_result[i], result[i]) - - elif isinstance(expected_result, list): - for i, k in zip(expected_result, result): - _helper(i, k) - else: - self.assertEqual(expected_result, result) - - _helper(expected_result, result) - return None - - def get_object_name(self, op): - """Allow subclasses to override handling of 'object' - - Transaction spec says 'object' is required. - """ - return op["object"] - - @staticmethod - def parse_options(opts): - return parse_spec_options(opts) - - async def run_operation(self, sessions, collection, operation): - original_collection = collection - name = camel_to_snake(operation["name"]) - if name == "run_command": - name = "command" - elif name == "download_by_name": - name = "open_download_stream_by_name" - elif name == "download": - name = "open_download_stream" - elif name == "map_reduce": - self.skipTest("PyMongo does not support mapReduce") - elif name == "count": - self.skipTest("PyMongo does not support count") - - database = collection.database - collection = database.get_collection(collection.name) - if "collectionOptions" in operation: - collection = collection.with_options( - **self.parse_options(operation["collectionOptions"]) - ) - - object_name = self.get_object_name(operation) - if object_name == "gridfsbucket": - # Only create the GridFSBucket when we need it (for the gridfs - # retryable reads tests). - obj = AsyncGridFSBucket(database, bucket_name=collection.name) - else: - objects = { - "client": database.client, - "database": database, - "collection": collection, - "testRunner": self, - } - objects.update(sessions) - obj = objects[object_name] - - # Combine arguments with options and handle special cases. - arguments = operation.get("arguments", {}) - arguments.update(arguments.pop("options", {})) - self.parse_options(arguments) - - cmd = getattr(obj, name) - - with_txn_callback = functools.partial( - self.run_operations, sessions, original_collection, in_with_transaction=True - ) - prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback) - - if name == "run_on_thread": - args = {"sessions": sessions, "collection": collection} - args.update(arguments) - arguments = args - - if not _IS_SYNC and iscoroutinefunction(cmd): - result = await cmd(**dict(arguments)) - else: - result = cmd(**dict(arguments)) - # Cleanup open change stream cursors. - if name == "watch": - self.addAsyncCleanup(result.close) - - if name == "aggregate": - if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]: - # Read from the primary to ensure causal consistency. - out = collection.database.get_collection( - arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY - ) - return out.find() - if "download" in name: - result = Binary(result.read()) - - if isinstance(result, AsyncCursor) or isinstance(result, AsyncCommandCursor): - return await result.to_list() - - return result - - def allowable_errors(self, op): - """Allow encryption spec to override expected error classes.""" - return (PyMongoError,) - - async def _run_op(self, sessions, collection, op, in_with_transaction): - expected_result = op.get("result") - if expect_error(op): - with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context: - await self.run_operation(sessions, collection, op.copy()) - exc = context.exception - if expect_error_message(expected_result): - if isinstance(exc, BulkWriteError): - errmsg = str(exc.details).lower() - else: - errmsg = str(exc).lower() - self.assertIn(expected_result["errorContains"].lower(), errmsg) - if expect_error_code(expected_result): - self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName")) - if expect_error_labels_contain(expected_result): - self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"]) - if expect_error_labels_omit(expected_result): - self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"]) - if expect_timeout_error(expected_result): - self.assertIsInstance(exc, PyMongoError) - if not exc.timeout: - # Re-raise the exception for better diagnostics. - raise exc - - # Reraise the exception if we're in the with_transaction - # callback. - if in_with_transaction: - raise context.exception - else: - result = await self.run_operation(sessions, collection, op.copy()) - if "result" in op: - if op["name"] == "runCommand": - self.check_command_result(expected_result, result) - else: - self.check_result(expected_result, result) - - async def run_operations(self, sessions, collection, ops, in_with_transaction=False): - for op in ops: - await self._run_op(sessions, collection, op, in_with_transaction) - - # TODO: factor with test_command_monitoring.py - def check_events(self, test, listener, session_ids): - events = listener.started_events - if not len(test["expectations"]): - return - - # Give a nicer message when there are missing or extra events - cmds = decode_raw([event.command for event in events]) - self.assertEqual(len(events), len(test["expectations"]), cmds) - for i, expectation in enumerate(test["expectations"]): - event_type = next(iter(expectation)) - event = events[i] - - # The tests substitute 42 for any number other than 0. - if event.command_name == "getMore" and event.command["getMore"]: - event.command["getMore"] = Int64(42) - elif event.command_name == "killCursors": - event.command["cursors"] = [Int64(42)] - elif event.command_name == "update": - # TODO: remove this once PYTHON-1744 is done. - # Add upsert and multi fields back into expectations. - updates = expectation[event_type]["command"]["updates"] - for update in updates: - update.setdefault("upsert", False) - update.setdefault("multi", False) - - # Replace afterClusterTime: 42 with actual afterClusterTime. - expected_cmd = expectation[event_type]["command"] - expected_read_concern = expected_cmd.get("readConcern") - if expected_read_concern is not None: - time = expected_read_concern.get("afterClusterTime") - if time == 42: - actual_time = event.command.get("readConcern", {}).get("afterClusterTime") - if actual_time is not None: - expected_read_concern["afterClusterTime"] = actual_time - - recovery_token = expected_cmd.get("recoveryToken") - if recovery_token == 42: - expected_cmd["recoveryToken"] = CompareType(dict) - - # Replace lsid with a name like "session0" to match test. - if "lsid" in event.command: - for name, lsid in session_ids.items(): - if event.command["lsid"] == lsid: - event.command["lsid"] = name - break - - for attr, expected in expectation[event_type].items(): - actual = getattr(event, attr) - expected = wrap_types(expected) - if isinstance(expected, dict): - for key, val in expected.items(): - if val is None: - if key in actual: - self.fail(f"Unexpected key [{key}] in {actual!r}") - elif key not in actual: - self.fail(f"Expected key [{key}] in {actual!r}") - else: - self.assertEqual( - val, decode_raw(actual[key]), f"Key [{key}] in {actual}" - ) - else: - self.assertEqual(actual, expected) - - def maybe_skip_scenario(self, test): - if test.get("skipReason"): - self.skipTest(test.get("skipReason")) - - def get_scenario_db_name(self, scenario_def): - """Allow subclasses to override a test's database name.""" - return scenario_def["database_name"] - - def get_scenario_coll_name(self, scenario_def): - """Allow subclasses to override a test's collection name.""" - return scenario_def["collection_name"] - - def get_outcome_coll_name(self, outcome, collection): - """Allow subclasses to override outcome collection.""" - return collection.name - - async def run_test_ops(self, sessions, collection, test): - """Added to allow retryable writes spec to override a test's - operation. - """ - await self.run_operations(sessions, collection, test["operations"]) - - def parse_client_options(self, opts): - """Allow encryption spec to override a clientOptions parsing.""" - return opts - - async def setup_scenario(self, scenario_def): - """Allow specs to override a test's setup.""" - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - documents = scenario_def["data"] - - # Setup the collection with as few majority writes as possible. - db = async_client_context.client.get_database(db_name) - coll_exists = bool(await db.list_collection_names(filter={"name": coll_name})) - if coll_exists: - await db[coll_name].delete_many({}) - # Only use majority wc only on the final write. - wc = WriteConcern(w="majority") - if documents: - db.get_collection(coll_name, write_concern=wc).insert_many(documents) - elif not coll_exists: - # Ensure collection exists. - await db.create_collection(coll_name, write_concern=wc) - - async def run_scenario(self, scenario_def, test): - self.maybe_skip_scenario(test) - - # Kill all sessions after each test with transactions to prevent an open - # transaction (from a test failure) from blocking collection/database - # operations during test set up and tear down. - for op in test["operations"]: - name = op["name"] - if name == "startTransaction" or name == "withTransaction": - self.addAsyncCleanup(self.kill_all_sessions) - break - await self.setup_scenario(scenario_def) - database_name = self.get_scenario_db_name(scenario_def) - collection_name = self.get_scenario_coll_name(scenario_def) - # SPEC-1245 workaround StaleDbVersion on distinct - for c in self.mongos_clients: - await c[database_name][collection_name].distinct("x") - - # Configure the fail point before creating the client. - if "failPoint" in test: - fp = test["failPoint"] - await self.set_fail_point(fp) - self.addAsyncCleanup( - self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} - ) - - listener = OvertCommandListener() - pool_listener = CMAPListener() - server_listener = ServerAndTopologyEventListener() - # Create a new client, to avoid interference from pooled sessions. - client_options = self.parse_client_options(test["clientOptions"]) - use_multi_mongos = test["useMultipleMongoses"] - host = None - if use_multi_mongos: - if async_client_context.load_balancer: - host = async_client_context.MULTI_MONGOS_LB_URI - elif async_client_context.is_mongos: - host = async_client_context.mongos_seeds() - client = await self.async_rs_client( - h=host, event_listeners=[listener, pool_listener, server_listener], **client_options - ) - self.scenario_client = client - self.listener = listener - self.pool_listener = pool_listener - self.server_listener = server_listener - - # Create session0 and session1. - sessions = {} - session_ids = {} - for i in range(2): - # Don't attempt to create sessions if they are not supported by - # the running server version. - if not async_client_context.sessions_enabled: - break - session_name = "session%d" % i - opts = camel_to_snake_args(test["sessionOptions"][session_name]) - if "default_transaction_options" in opts: - txn_opts = self.parse_options(opts["default_transaction_options"]) - txn_opts = client_session.TransactionOptions(**txn_opts) - opts["default_transaction_options"] = txn_opts - - s = client.start_session(**dict(opts)) - - sessions[session_name] = s - # Store lsid so we can access it after end_session, in check_events. - session_ids[session_name] = s.session_id - - self.addAsyncCleanup(end_sessions, sessions) - - collection = client[database_name][collection_name] - await self.run_test_ops(sessions, collection, test) - - await end_sessions(sessions) - - self.check_events(test, listener, session_ids) - - # Disable fail points. - if "failPoint" in test: - fp = test["failPoint"] - await self.set_fail_point( - {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} - ) - - # Assert final state is expected. - outcome = test["outcome"] - expected_c = outcome.get("collection") - if expected_c is not None: - outcome_coll_name = self.get_outcome_coll_name(outcome, collection) - - # Read from the primary with local read concern to ensure causal - # consistency. - outcome_coll = async_client_context.client[collection.database.name].get_collection( - outcome_coll_name, - read_preference=ReadPreference.PRIMARY, - read_concern=ReadConcern("local"), - ) - actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list() - - # The expected data needs to be the left hand side here otherwise - # CompareType(Binary) doesn't work. - self.assertEqual(wrap_types(expected_c["data"]), actual_data) - - -def expect_any_error(op): - if isinstance(op, dict): - return op.get("error") - - return False - - -def expect_error_message(expected_result): - if isinstance(expected_result, dict): - return isinstance(expected_result["errorContains"], str) - - return False - - -def expect_error_code(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorCodeName"] - - return False - - -def expect_error_labels_contain(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorLabelsContain"] - - return False - - -def expect_error_labels_omit(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorLabelsOmit"] - - return False - - -def expect_timeout_error(expected_result): - if isinstance(expected_result, dict): - return expected_result["isTimeoutError"] - - return False - - -def expect_error(op): - expected_result = op.get("result") - return ( - expect_any_error(op) - or expect_error_message(expected_result) - or expect_error_code(expected_result) - or expect_error_labels_contain(expected_result) - or expect_error_labels_omit(expected_result) - or expect_timeout_error(expected_result) - ) - - -async def end_sessions(sessions): - for s in sessions.values(): - # Aborts the transaction if it's open. - await s.end_session() - - -def decode_raw(val): - """Decode RawBSONDocuments in the given container.""" - if isinstance(val, (list, abc.Mapping)): - return decode(encode({"v": val}))["v"] - return val - - -TYPES = { - "binData": Binary, - "long": Int64, - "int": int, - "string": str, - "objectId": ObjectId, - "object": dict, - "array": list, -} - - -def wrap_types(val): - """Support $$type assertion in command results.""" - if isinstance(val, list): - return [wrap_types(v) for v in val] - if isinstance(val, abc.Mapping): - typ = val.get("$$type") - if typ: - if isinstance(typ, str): - types = TYPES[typ] - else: - types = tuple(TYPES[t] for t in typ) - return CompareType(types) - d = {} - for key in val: - d[key] = wrap_types(val[key]) - return d - return val diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 34e1c95ef2..95e580cef9 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -16,43 +16,14 @@ from __future__ import annotations import asyncio -import functools import os import time -import unittest -from collections import abc -from inspect import iscoroutinefunction -from test import IntegrationTest, client_context, client_knobs +from test import client_context from test.helpers import ConcurrentRunner -from test.utils_shared import ( - CMAPListener, - CompareType, - EventListener, - OvertCommandListener, - ScenarioDict, - ServerAndTopologyEventListener, - camel_to_snake, - camel_to_snake_args, - parse_spec_options, - prepare_spec_arguments, -) -from typing import List - -from bson import ObjectId, decode, encode, json_util -from bson.binary import Binary -from bson.int64 import Int64 -from bson.son import SON -from gridfs import GridFSBucket -from gridfs.synchronous.grid_file import GridFSBucket -from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from test.utils_shared import ScenarioDict + +from bson import json_util from pymongo.lock import _cond_wait, _create_condition, _create_lock -from pymongo.read_concern import ReadConcern -from pymongo.read_preferences import ReadPreference -from pymongo.results import BulkWriteResult, _WriteResult -from pymongo.synchronous import client_session -from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.cursor import Cursor -from pymongo.write_concern import WriteConcern _IS_SYNC = True @@ -219,598 +190,3 @@ def create_tests(self): self._create_tests() else: asyncio.run(self._create_tests()) - - -class SpecRunner(IntegrationTest): - mongos_clients: List - knobs: client_knobs - listener: EventListener - - def setUp(self) -> None: - super().setUp() - self.mongos_clients = [] - - # Speed up the tests by decreasing the heartbeat frequency. - self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - self.knobs.enable() - self.targets = {} - self.listener = None # type: ignore - self.pool_listener = None - self.server_listener = None - self.maxDiff = None - - def tearDown(self) -> None: - self.knobs.disable() - - def set_fail_point(self, command_args): - clients = self.mongos_clients if self.mongos_clients else [self.client] - for client in clients: - self.configure_fail_point(client, command_args) - - def targeted_fail_point(self, session, fail_point): - """Run the targetedFailPoint test operation. - - Enable the fail point on the session's pinned mongos. - """ - clients = {c.address: c for c in self.mongos_clients} - client = clients[session._pinned_address] - self.configure_fail_point(client, fail_point) - self.addCleanup(self.set_fail_point, {"mode": "off"}) - - def assert_session_pinned(self, session): - """Run the assertSessionPinned test operation. - - Assert that the given session is pinned. - """ - self.assertIsNotNone(session._transaction.pinned_address) - - def assert_session_unpinned(self, session): - """Run the assertSessionUnpinned test operation. - - Assert that the given session is not pinned. - """ - self.assertIsNone(session._pinned_address) - self.assertIsNone(session._transaction.pinned_address) - - def assert_collection_exists(self, database, collection): - """Run the assertCollectionExists test operation.""" - db = self.client[database] - self.assertIn(collection, db.list_collection_names()) - - def assert_collection_not_exists(self, database, collection): - """Run the assertCollectionNotExists test operation.""" - db = self.client[database] - self.assertNotIn(collection, db.list_collection_names()) - - def assert_index_exists(self, database, collection, index): - """Run the assertIndexExists test operation.""" - coll = self.client[database][collection] - self.assertIn(index, [doc["name"] for doc in coll.list_indexes()]) - - def assert_index_not_exists(self, database, collection, index): - """Run the assertIndexNotExists test operation.""" - coll = self.client[database][collection] - self.assertNotIn(index, [doc["name"] for doc in coll.list_indexes()]) - - def wait(self, ms): - """Run the "wait" test operation.""" - time.sleep(ms / 1000.0) - - def assertErrorLabelsContain(self, exc, expected_labels): - labels = [l for l in expected_labels if exc.has_error_label(l)] - self.assertEqual(labels, expected_labels) - - def assertErrorLabelsOmit(self, exc, omit_labels): - for label in omit_labels: - self.assertFalse( - exc.has_error_label(label), msg=f"error labels should not contain {label}" - ) - - def kill_all_sessions(self): - clients = self.mongos_clients if self.mongos_clients else [self.client] - for client in clients: - try: - client.admin.command("killAllSessions", []) - except (OperationFailure, AutoReconnect): - # "operation was interrupted" by killing the command's - # own session. - # On 8.0+ killAllSessions sometimes returns a network error. - pass - - def check_command_result(self, expected_result, result): - # Only compare the keys in the expected result. - filtered_result = {} - for key in expected_result: - try: - filtered_result[key] = result[key] - except KeyError: - pass - self.assertEqual(filtered_result, expected_result) - - # TODO: factor the following function with test_crud.py. - def check_result(self, expected_result, result): - if isinstance(result, _WriteResult): - for res in expected_result: - prop = camel_to_snake(res) - # SPEC-869: Only BulkWriteResult has upserted_count. - if prop == "upserted_count" and not isinstance(result, BulkWriteResult): - if result.upserted_id is not None: - upserted_count = 1 - else: - upserted_count = 0 - self.assertEqual(upserted_count, expected_result[res], prop) - elif prop == "inserted_ids": - # BulkWriteResult does not have inserted_ids. - if isinstance(result, BulkWriteResult): - self.assertEqual(len(expected_result[res]), result.inserted_count) - else: - # InsertManyResult may be compared to [id1] from the - # crud spec or {"0": id1} from the retryable write spec. - ids = expected_result[res] - if isinstance(ids, dict): - ids = [ids[str(i)] for i in range(len(ids))] - - self.assertEqual(ids, result.inserted_ids, prop) - elif prop == "upserted_ids": - # Convert indexes from strings to integers. - ids = expected_result[res] - expected_ids = {} - for str_index in ids: - expected_ids[int(str_index)] = ids[str_index] - self.assertEqual(expected_ids, result.upserted_ids, prop) - else: - self.assertEqual(getattr(result, prop), expected_result[res], prop) - - return True - else: - - def _helper(expected_result, result): - if isinstance(expected_result, abc.Mapping): - for i in expected_result.keys(): - self.assertEqual(expected_result[i], result[i]) - - elif isinstance(expected_result, list): - for i, k in zip(expected_result, result): - _helper(i, k) - else: - self.assertEqual(expected_result, result) - - _helper(expected_result, result) - return None - - def get_object_name(self, op): - """Allow subclasses to override handling of 'object' - - Transaction spec says 'object' is required. - """ - return op["object"] - - @staticmethod - def parse_options(opts): - return parse_spec_options(opts) - - def run_operation(self, sessions, collection, operation): - original_collection = collection - name = camel_to_snake(operation["name"]) - if name == "run_command": - name = "command" - elif name == "download_by_name": - name = "open_download_stream_by_name" - elif name == "download": - name = "open_download_stream" - elif name == "map_reduce": - self.skipTest("PyMongo does not support mapReduce") - elif name == "count": - self.skipTest("PyMongo does not support count") - - database = collection.database - collection = database.get_collection(collection.name) - if "collectionOptions" in operation: - collection = collection.with_options( - **self.parse_options(operation["collectionOptions"]) - ) - - object_name = self.get_object_name(operation) - if object_name == "gridfsbucket": - # Only create the GridFSBucket when we need it (for the gridfs - # retryable reads tests). - obj = GridFSBucket(database, bucket_name=collection.name) - else: - objects = { - "client": database.client, - "database": database, - "collection": collection, - "testRunner": self, - } - objects.update(sessions) - obj = objects[object_name] - - # Combine arguments with options and handle special cases. - arguments = operation.get("arguments", {}) - arguments.update(arguments.pop("options", {})) - self.parse_options(arguments) - - cmd = getattr(obj, name) - - with_txn_callback = functools.partial( - self.run_operations, sessions, original_collection, in_with_transaction=True - ) - prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback) - - if name == "run_on_thread": - args = {"sessions": sessions, "collection": collection} - args.update(arguments) - arguments = args - - if not _IS_SYNC and iscoroutinefunction(cmd): - result = cmd(**dict(arguments)) - else: - result = cmd(**dict(arguments)) - # Cleanup open change stream cursors. - if name == "watch": - self.addCleanup(result.close) - - if name == "aggregate": - if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]: - # Read from the primary to ensure causal consistency. - out = collection.database.get_collection( - arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY - ) - return out.find() - if "download" in name: - result = Binary(result.read()) - - if isinstance(result, Cursor) or isinstance(result, CommandCursor): - return result.to_list() - - return result - - def allowable_errors(self, op): - """Allow encryption spec to override expected error classes.""" - return (PyMongoError,) - - def _run_op(self, sessions, collection, op, in_with_transaction): - expected_result = op.get("result") - if expect_error(op): - with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context: - self.run_operation(sessions, collection, op.copy()) - exc = context.exception - if expect_error_message(expected_result): - if isinstance(exc, BulkWriteError): - errmsg = str(exc.details).lower() - else: - errmsg = str(exc).lower() - self.assertIn(expected_result["errorContains"].lower(), errmsg) - if expect_error_code(expected_result): - self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName")) - if expect_error_labels_contain(expected_result): - self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"]) - if expect_error_labels_omit(expected_result): - self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"]) - if expect_timeout_error(expected_result): - self.assertIsInstance(exc, PyMongoError) - if not exc.timeout: - # Re-raise the exception for better diagnostics. - raise exc - - # Reraise the exception if we're in the with_transaction - # callback. - if in_with_transaction: - raise context.exception - else: - result = self.run_operation(sessions, collection, op.copy()) - if "result" in op: - if op["name"] == "runCommand": - self.check_command_result(expected_result, result) - else: - self.check_result(expected_result, result) - - def run_operations(self, sessions, collection, ops, in_with_transaction=False): - for op in ops: - self._run_op(sessions, collection, op, in_with_transaction) - - # TODO: factor with test_command_monitoring.py - def check_events(self, test, listener, session_ids): - events = listener.started_events - if not len(test["expectations"]): - return - - # Give a nicer message when there are missing or extra events - cmds = decode_raw([event.command for event in events]) - self.assertEqual(len(events), len(test["expectations"]), cmds) - for i, expectation in enumerate(test["expectations"]): - event_type = next(iter(expectation)) - event = events[i] - - # The tests substitute 42 for any number other than 0. - if event.command_name == "getMore" and event.command["getMore"]: - event.command["getMore"] = Int64(42) - elif event.command_name == "killCursors": - event.command["cursors"] = [Int64(42)] - elif event.command_name == "update": - # TODO: remove this once PYTHON-1744 is done. - # Add upsert and multi fields back into expectations. - updates = expectation[event_type]["command"]["updates"] - for update in updates: - update.setdefault("upsert", False) - update.setdefault("multi", False) - - # Replace afterClusterTime: 42 with actual afterClusterTime. - expected_cmd = expectation[event_type]["command"] - expected_read_concern = expected_cmd.get("readConcern") - if expected_read_concern is not None: - time = expected_read_concern.get("afterClusterTime") - if time == 42: - actual_time = event.command.get("readConcern", {}).get("afterClusterTime") - if actual_time is not None: - expected_read_concern["afterClusterTime"] = actual_time - - recovery_token = expected_cmd.get("recoveryToken") - if recovery_token == 42: - expected_cmd["recoveryToken"] = CompareType(dict) - - # Replace lsid with a name like "session0" to match test. - if "lsid" in event.command: - for name, lsid in session_ids.items(): - if event.command["lsid"] == lsid: - event.command["lsid"] = name - break - - for attr, expected in expectation[event_type].items(): - actual = getattr(event, attr) - expected = wrap_types(expected) - if isinstance(expected, dict): - for key, val in expected.items(): - if val is None: - if key in actual: - self.fail(f"Unexpected key [{key}] in {actual!r}") - elif key not in actual: - self.fail(f"Expected key [{key}] in {actual!r}") - else: - self.assertEqual( - val, decode_raw(actual[key]), f"Key [{key}] in {actual}" - ) - else: - self.assertEqual(actual, expected) - - def maybe_skip_scenario(self, test): - if test.get("skipReason"): - self.skipTest(test.get("skipReason")) - - def get_scenario_db_name(self, scenario_def): - """Allow subclasses to override a test's database name.""" - return scenario_def["database_name"] - - def get_scenario_coll_name(self, scenario_def): - """Allow subclasses to override a test's collection name.""" - return scenario_def["collection_name"] - - def get_outcome_coll_name(self, outcome, collection): - """Allow subclasses to override outcome collection.""" - return collection.name - - def run_test_ops(self, sessions, collection, test): - """Added to allow retryable writes spec to override a test's - operation. - """ - self.run_operations(sessions, collection, test["operations"]) - - def parse_client_options(self, opts): - """Allow encryption spec to override a clientOptions parsing.""" - return opts - - def setup_scenario(self, scenario_def): - """Allow specs to override a test's setup.""" - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - documents = scenario_def["data"] - - # Setup the collection with as few majority writes as possible. - db = client_context.client.get_database(db_name) - coll_exists = bool(db.list_collection_names(filter={"name": coll_name})) - if coll_exists: - db[coll_name].delete_many({}) - # Only use majority wc only on the final write. - wc = WriteConcern(w="majority") - if documents: - db.get_collection(coll_name, write_concern=wc).insert_many(documents) - elif not coll_exists: - # Ensure collection exists. - db.create_collection(coll_name, write_concern=wc) - - def run_scenario(self, scenario_def, test): - self.maybe_skip_scenario(test) - - # Kill all sessions after each test with transactions to prevent an open - # transaction (from a test failure) from blocking collection/database - # operations during test set up and tear down. - for op in test["operations"]: - name = op["name"] - if name == "startTransaction" or name == "withTransaction": - self.addCleanup(self.kill_all_sessions) - break - self.setup_scenario(scenario_def) - database_name = self.get_scenario_db_name(scenario_def) - collection_name = self.get_scenario_coll_name(scenario_def) - # SPEC-1245 workaround StaleDbVersion on distinct - for c in self.mongos_clients: - c[database_name][collection_name].distinct("x") - - # Configure the fail point before creating the client. - if "failPoint" in test: - fp = test["failPoint"] - self.set_fail_point(fp) - self.addCleanup( - self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} - ) - - listener = OvertCommandListener() - pool_listener = CMAPListener() - server_listener = ServerAndTopologyEventListener() - # Create a new client, to avoid interference from pooled sessions. - client_options = self.parse_client_options(test["clientOptions"]) - use_multi_mongos = test["useMultipleMongoses"] - host = None - if use_multi_mongos: - if client_context.load_balancer: - host = client_context.MULTI_MONGOS_LB_URI - elif client_context.is_mongos: - host = client_context.mongos_seeds() - client = self.rs_client( - h=host, event_listeners=[listener, pool_listener, server_listener], **client_options - ) - self.scenario_client = client - self.listener = listener - self.pool_listener = pool_listener - self.server_listener = server_listener - - # Create session0 and session1. - sessions = {} - session_ids = {} - for i in range(2): - # Don't attempt to create sessions if they are not supported by - # the running server version. - if not client_context.sessions_enabled: - break - session_name = "session%d" % i - opts = camel_to_snake_args(test["sessionOptions"][session_name]) - if "default_transaction_options" in opts: - txn_opts = self.parse_options(opts["default_transaction_options"]) - txn_opts = client_session.TransactionOptions(**txn_opts) - opts["default_transaction_options"] = txn_opts - - s = client.start_session(**dict(opts)) - - sessions[session_name] = s - # Store lsid so we can access it after end_session, in check_events. - session_ids[session_name] = s.session_id - - self.addCleanup(end_sessions, sessions) - - collection = client[database_name][collection_name] - self.run_test_ops(sessions, collection, test) - - end_sessions(sessions) - - self.check_events(test, listener, session_ids) - - # Disable fail points. - if "failPoint" in test: - fp = test["failPoint"] - self.set_fail_point({"configureFailPoint": fp["configureFailPoint"], "mode": "off"}) - - # Assert final state is expected. - outcome = test["outcome"] - expected_c = outcome.get("collection") - if expected_c is not None: - outcome_coll_name = self.get_outcome_coll_name(outcome, collection) - - # Read from the primary with local read concern to ensure causal - # consistency. - outcome_coll = client_context.client[collection.database.name].get_collection( - outcome_coll_name, - read_preference=ReadPreference.PRIMARY, - read_concern=ReadConcern("local"), - ) - actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list() - - # The expected data needs to be the left hand side here otherwise - # CompareType(Binary) doesn't work. - self.assertEqual(wrap_types(expected_c["data"]), actual_data) - - -def expect_any_error(op): - if isinstance(op, dict): - return op.get("error") - - return False - - -def expect_error_message(expected_result): - if isinstance(expected_result, dict): - return isinstance(expected_result["errorContains"], str) - - return False - - -def expect_error_code(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorCodeName"] - - return False - - -def expect_error_labels_contain(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorLabelsContain"] - - return False - - -def expect_error_labels_omit(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorLabelsOmit"] - - return False - - -def expect_timeout_error(expected_result): - if isinstance(expected_result, dict): - return expected_result["isTimeoutError"] - - return False - - -def expect_error(op): - expected_result = op.get("result") - return ( - expect_any_error(op) - or expect_error_message(expected_result) - or expect_error_code(expected_result) - or expect_error_labels_contain(expected_result) - or expect_error_labels_omit(expected_result) - or expect_timeout_error(expected_result) - ) - - -def end_sessions(sessions): - for s in sessions.values(): - # Aborts the transaction if it's open. - s.end_session() - - -def decode_raw(val): - """Decode RawBSONDocuments in the given container.""" - if isinstance(val, (list, abc.Mapping)): - return decode(encode({"v": val}))["v"] - return val - - -TYPES = { - "binData": Binary, - "long": Int64, - "int": int, - "string": str, - "objectId": ObjectId, - "object": dict, - "array": list, -} - - -def wrap_types(val): - """Support $$type assertion in command results.""" - if isinstance(val, list): - return [wrap_types(v) for v in val] - if isinstance(val, abc.Mapping): - typ = val.get("$$type") - if typ: - if isinstance(typ, str): - types = TYPES[typ] - else: - types = tuple(TYPES[t] for t in typ) - return CompareType(types) - d = {} - for key in val: - d[key] = wrap_types(val[key]) - return d - return val