From 853c82540283bf08a78cc5f50c287a2ff136d06c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 1 Apr 2026 15:02:18 -0700 Subject: [PATCH 1/4] chore(firestore): skip pipeline verification tests outside enterprise db --- .../google-cloud-firestore/tests/system/test_system.py | 8 ++++++-- .../tests/system/test_system_async.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/packages/google-cloud-firestore/tests/system/test_system.py b/packages/google-cloud-firestore/tests/system/test_system.py index f3cdf13b09a1..c28c4dcc363d 100644 --- a/packages/google-cloud-firestore/tests/system/test_system.py +++ b/packages/google-cloud-firestore/tests/system/test_system.py @@ -94,8 +94,13 @@ def verify_pipeline(query): """ from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + client = query._client if FIRESTORE_EMULATOR: - pytest.skip("skip pipeline verification on emulator") + print("skip pipeline verification on emulator") + return + if client._database != FIRESTORE_ENTERPRISE_DB: + print("pipelines only supports enterprise db") + return def _clean_results(results): if isinstance(results, dict): @@ -126,7 +131,6 @@ def _clean_results(results): except Exception as e: # if we expect the query to fail, capture the exception query_exception = e - client = query._client pipeline = client.pipeline().create_from(query) if query_exception: # ensure that the pipeline uses same error as query diff --git a/packages/google-cloud-firestore/tests/system/test_system_async.py b/packages/google-cloud-firestore/tests/system/test_system_async.py index b2806a0dc68e..661b5e20b96a 100644 --- a/packages/google-cloud-firestore/tests/system/test_system_async.py +++ b/packages/google-cloud-firestore/tests/system/test_system_async.py @@ -174,8 +174,13 @@ async def verify_pipeline(query): """ from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + client = query._client if FIRESTORE_EMULATOR: - pytest.skip("skip pipeline verification on emulator") + print("skip pipeline verification on emulator") + return + if client._database != FIRESTORE_ENTERPRISE_DB: + print("pipelines only supports enterprise db") + return def _clean_results(results): if isinstance(results, dict): @@ -206,7 +211,6 @@ def _clean_results(results): except Exception as e: # if we expect the query to fail, capture the exception query_exception = e - client = query._client pipeline = client.pipeline().create_from(query) if query_exception: # ensure that the pipeline uses same error as query From af55183089b0225143bfe7eb4bd5b6a9a27c8005 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 1 Apr 2026 15:04:30 -0700 Subject: [PATCH 2/4] added docstring --- packages/google-cloud-firestore/tests/system/test_system.py | 2 ++ .../google-cloud-firestore/tests/system/test_system_async.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/packages/google-cloud-firestore/tests/system/test_system.py b/packages/google-cloud-firestore/tests/system/test_system.py index c28c4dcc363d..d61def9fdaf9 100644 --- a/packages/google-cloud-firestore/tests/system/test_system.py +++ b/packages/google-cloud-firestore/tests/system/test_system.py @@ -91,6 +91,8 @@ def verify_pipeline(query): It can be attached to existing query tests to check both modalities at the same time + + Pipelines are only supported on enterprise dbs. Skip other environments """ from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery diff --git a/packages/google-cloud-firestore/tests/system/test_system_async.py b/packages/google-cloud-firestore/tests/system/test_system_async.py index 661b5e20b96a..0348689e63ad 100644 --- a/packages/google-cloud-firestore/tests/system/test_system_async.py +++ b/packages/google-cloud-firestore/tests/system/test_system_async.py @@ -171,6 +171,8 @@ async def verify_pipeline(query): It can be attached to existing query tests to check both modalities at the same time + + Pipelines are only supported on enterprise dbs. Skip other environments """ from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery From 32b4db025f063b77e5f7a609115f8b8013b1a3be Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 1 Apr 2026 15:45:23 -0700 Subject: [PATCH 3/4] use subtests for verify_pipeline --- packages/google-cloud-firestore/noxfile.py | 2 +- .../tests/system/test_system.py | 156 +++++++++--------- .../tests/system/test_system_async.py | 144 ++++++++-------- 3 files changed, 158 insertions(+), 144 deletions(-) diff --git a/packages/google-cloud-firestore/noxfile.py b/packages/google-cloud-firestore/noxfile.py index 588dd7c0058d..5a7c0a1b8536 100644 --- a/packages/google-cloud-firestore/noxfile.py +++ b/packages/google-cloud-firestore/noxfile.py @@ -71,7 +71,7 @@ SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ALL_PYTHON SYSTEM_TEST_STANDARD_DEPENDENCIES = [ "mock", - "pytest", + "pytest>9.0", "google-cloud-testutils", ] SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ diff --git a/packages/google-cloud-firestore/tests/system/test_system.py b/packages/google-cloud-firestore/tests/system/test_system.py index d61def9fdaf9..66822ef7e585 100644 --- a/packages/google-cloud-firestore/tests/system/test_system.py +++ b/packages/google-cloud-firestore/tests/system/test_system.py @@ -83,69 +83,73 @@ def cleanup(): for operation in operations: operation() - -def verify_pipeline(query): +@pytest.fixture +def verify_pipeline(subtests): """ - This function ensures a pipeline produces the same - results as the query it is derived from + This fixture provide a subtest function which + ensures a pipeline produces the same results as the query it is derived + from It can be attached to existing query tests to check both modalities at the same time Pipelines are only supported on enterprise dbs. Skip other environments """ - from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery - - client = query._client - if FIRESTORE_EMULATOR: - print("skip pipeline verification on emulator") - return - if client._database != FIRESTORE_ENTERPRISE_DB: - print("pipelines only supports enterprise db") - return - - def _clean_results(results): - if isinstance(results, dict): - return {k: _clean_results(v) for k, v in results.items()} - elif isinstance(results, list): - return [_clean_results(r) for r in results] - elif isinstance(results, float) and math.isnan(results): - return "__NAN_VALUE__" - else: - return results - query_exception = None - query_results = None - try: - try: - if isinstance(query, BaseAggregationQuery): - # aggregation queries return a list of lists of aggregation results - query_results = _clean_results( - list( - itertools.chain.from_iterable( - [[a._to_dict() for a in s] for s in query.get()] + def _verifier(query): + from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + with subtests.test(msg="verify_pipeline"): + + client = query._client + if FIRESTORE_EMULATOR: + pytest.skip("skip pipeline verification on emulator") + if client._database != FIRESTORE_ENTERPRISE_DB: + pytest.skip("pipelines only supports enterprise db") + + def _clean_results(results): + if isinstance(results, dict): + return {k: _clean_results(v) for k, v in results.items()} + elif isinstance(results, list): + return [_clean_results(r) for r in results] + elif isinstance(results, float) and math.isnan(results): + return "__NAN_VALUE__" + else: + return results + + query_exception = None + query_results = None + try: + try: + if isinstance(query, BaseAggregationQuery): + # aggregation queries return a list of lists of aggregation results + query_results = _clean_results( + list( + itertools.chain.from_iterable( + [[a._to_dict() for a in s] for s in query.get()] + ) + ) ) - ) - ) - else: - # other qureies return a simple list of results - query_results = _clean_results([s.to_dict() for s in query.get()]) - except Exception as e: - # if we expect the query to fail, capture the exception - query_exception = e - pipeline = client.pipeline().create_from(query) - if query_exception: - # ensure that the pipeline uses same error as query - with pytest.raises(query_exception.__class__): - pipeline.execute() - else: - # ensure results match query - pipeline_results = _clean_results([s.data() for s in pipeline.execute()]) - assert query_results == pipeline_results - except FailedPrecondition as e: - # if testing against a non-enterprise db, skip this check - if ENTERPRISE_MODE_ERROR not in e.message: - raise e + else: + # other qureies return a simple list of results + query_results = _clean_results([s.to_dict() for s in query.get()]) + except Exception as e: + # if we expect the query to fail, capture the exception + query_exception = e + pipeline = client.pipeline().create_from(query) + if query_exception: + # ensure that the pipeline uses same error as query + with pytest.raises(query_exception.__class__): + pipeline.execute() + else: + # ensure results match query + pipeline_results = _clean_results([s.data() for s in pipeline.execute()]) + assert query_results == pipeline_results + except FailedPrecondition as e: + # if testing against a non-enterprise db, skip this check + if ENTERPRISE_MODE_ERROR not in e.message: + raise e + + return _verifier @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1300,7 +1304,7 @@ def query(collection): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_legacy_where(query_docs, database): +def test_query_stream_legacy_where(query_docs, database, verify_pipeline): """Assert the legacy code still works and returns value""" collection, stored, allowed_vals = query_docs with pytest.warns( @@ -1317,7 +1321,7 @@ def test_query_stream_legacy_where(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_simple_field_eq_op(query_docs, database): +def test_query_stream_w_simple_field_eq_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1329,7 +1333,7 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_simple_field_array_contains_op(query_docs, database): +def test_query_stream_w_simple_field_array_contains_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1341,7 +1345,7 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_simple_field_in_op(query_docs, database): +def test_query_stream_w_simple_field_in_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("a", "in", [1, num_vals + 100])) @@ -1354,7 +1358,7 @@ def test_query_stream_w_simple_field_in_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_not_eq_op(query_docs, database): +def test_query_stream_w_not_eq_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", "!=", 4)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1377,7 +1381,7 @@ def test_query_stream_w_not_eq_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_simple_not_in_op(query_docs, database): +def test_query_stream_w_simple_not_in_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where( @@ -1390,7 +1394,7 @@ def test_query_stream_w_simple_not_in_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): +def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where( @@ -1405,7 +1409,7 @@ def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_order_by(query_docs, database): +def test_query_stream_w_order_by(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) values = [(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()] @@ -1420,7 +1424,7 @@ def test_query_stream_w_order_by(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_field_path(query_docs, database): +def test_query_stream_w_field_path(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1459,7 +1463,7 @@ def test_query_stream_w_start_end_cursor(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_wo_results(query_docs, database): +def test_query_stream_wo_results(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "==", num_vals + 100)) @@ -1486,7 +1490,7 @@ def test_query_stream_w_projection(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_multiple_filters(query_docs, database): +def test_query_stream_w_multiple_filters(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( filter=FieldFilter("stats.product", "<", 10) @@ -1507,7 +1511,7 @@ def test_query_stream_w_multiple_filters(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_offset(query_docs, database): +def test_query_stream_w_offset(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) offset = 3 @@ -1528,7 +1532,7 @@ def test_query_stream_w_offset(query_docs, database): ) @pytest.mark.parametrize("method", ["stream", "get"]) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): +def test_query_stream_or_get_w_no_explain_options(query_docs, database, method, verify_pipeline): from google.cloud.firestore_v1.query_profile import QueryExplainError collection, _, allowed_vals = query_docs @@ -1892,7 +1896,7 @@ def test_query_with_order_dot_key(client, cleanup, database): @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) -def test_query_unary(client, cleanup, database): +def test_query_unary(client, cleanup, database, verify_pipeline): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) field_name = "foo" @@ -1949,7 +1953,7 @@ def test_query_unary(client, cleanup, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_collection_group_queries(client, cleanup, database): +def test_collection_group_queries(client, cleanup, database, verify_pipeline): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ @@ -2026,7 +2030,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_collection_group_queries_filters(client, cleanup, database): +def test_collection_group_queries_filters(client, cleanup, database, verify_pipeline): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ @@ -2817,7 +2821,7 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_repro_429(client, cleanup, database): +def test_repro_429(client, cleanup, database, verify_pipeline): # See: https://github.com/googleapis/python-firestore/issues/429 now = datetime.datetime.now(tz=datetime.timezone.utc) collection = client.collection("repro-429" + UNIQUE_RESOURCE_ID) @@ -3412,7 +3416,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_false( @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_with_and_composite_filter(collection, database): +def test_query_with_and_composite_filter(collection, database, verify_pipeline): and_filter = And( filters=[ FieldFilter("stats.product", ">", 5), @@ -3428,7 +3432,7 @@ def test_query_with_and_composite_filter(collection, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_with_or_composite_filter(collection, database): +def test_query_with_or_composite_filter(collection, database, verify_pipeline): or_filter = Or( filters=[ FieldFilter("stats.product", ">", 5), @@ -3462,6 +3466,7 @@ def test_aggregation_queries_with_read_time( database, aggregation_type, expected_value, + verify_pipeline, ): """ Ensure that all aggregation queries work when read_time is passed into @@ -3500,7 +3505,7 @@ def test_aggregation_queries_with_read_time( @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_with_complex_composite_filter(collection, database): +def test_query_with_complex_composite_filter(collection, database, verify_pipeline): field_filter = FieldFilter("b", "==", 0) or_filter = Or( filters=[FieldFilter("stats.sum", "==", 0), FieldFilter("stats.sum", "==", 4)] @@ -3558,6 +3563,7 @@ def test_aggregation_query_in_transaction( aggregation_type, aggregation_args, expected, + verify_pipeline, ): """ Test creating an aggregation query inside a transaction @@ -3599,7 +3605,7 @@ def in_transaction(transaction): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_or_query_in_transaction(client, cleanup, database): +def test_or_query_in_transaction(client, cleanup, database, verify_pipeline): """ Test running or query inside a transaction. Should pass transaction id along with request """ diff --git a/packages/google-cloud-firestore/tests/system/test_system_async.py b/packages/google-cloud-firestore/tests/system/test_system_async.py index 0348689e63ad..c81dfb9e21e1 100644 --- a/packages/google-cloud-firestore/tests/system/test_system_async.py +++ b/packages/google-cloud-firestore/tests/system/test_system_async.py @@ -164,70 +164,78 @@ async def cleanup(): await operation() -async def verify_pipeline(query): +@pytest.fixture +def verify_pipeline(subtests): """ - This function ensures a pipeline produces the same - results as the query it is derived from + This fixture provide a subtest function which + ensures a pipeline produces the same results as the query it is derived + from It can be attached to existing query tests to check both modalities at the same time Pipelines are only supported on enterprise dbs. Skip other environments """ - from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery - client = query._client - if FIRESTORE_EMULATOR: - print("skip pipeline verification on emulator") - return - if client._database != FIRESTORE_ENTERPRISE_DB: - print("pipelines only supports enterprise db") - return - - def _clean_results(results): - if isinstance(results, dict): - return {k: _clean_results(v) for k, v in results.items()} - elif isinstance(results, list): - return [_clean_results(r) for r in results] - elif isinstance(results, float) and math.isnan(results): - return "__NAN_VALUE__" - else: - return results - - query_exception = None - query_results = None - try: - try: - if isinstance(query, BaseAggregationQuery): - # aggregation queries return a list of lists of aggregation results - query_results = _clean_results( - list( - itertools.chain.from_iterable( - [[a._to_dict() for a in s] for s in await query.get()] + async def _verifier(query): + from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + + with subtests.test(msg="verify_pipeline"): + + client = query._client + if FIRESTORE_EMULATOR: + pytest.skip("skip pipeline verification on emulator") + if client._database != FIRESTORE_ENTERPRISE_DB: + pytest.skip("pipelines only supports enterprise db") + + def _clean_results(results): + if isinstance(results, dict): + return {k: _clean_results(v) for k, v in results.items()} + elif isinstance(results, list): + return [_clean_results(r) for r in results] + elif isinstance(results, float) and math.isnan(results): + return "__NAN_VALUE__" + else: + return results + + query_exception = None + query_results = None + try: + try: + if isinstance(query, BaseAggregationQuery): + # aggregation queries return a list of lists of aggregation results + query_results = _clean_results( + list( + itertools.chain.from_iterable( + [[a._to_dict() for a in s] for s in await query.get()] + ) + ) ) + else: + # other qureies return a simple list of results + query_results = _clean_results( + [s.to_dict() for s in await query.get()] + ) + except Exception as e: + # if we expect the query to fail, capture the exception + query_exception = e + pipeline = client.pipeline().create_from(query) + if query_exception: + # ensure that the pipeline uses same error as query + with pytest.raises(query_exception.__class__): + await pipeline.execute() + else: + # ensure results match query + pipeline_results = _clean_results( + [s.data() async for s in pipeline.stream()] ) - ) - else: - # other qureies return a simple list of results - query_results = _clean_results([s.to_dict() for s in await query.get()]) - except Exception as e: - # if we expect the query to fail, capture the exception - query_exception = e - pipeline = client.pipeline().create_from(query) - if query_exception: - # ensure that the pipeline uses same error as query - with pytest.raises(query_exception.__class__): - await pipeline.execute() - else: - # ensure results match query - pipeline_results = _clean_results( - [s.data() async for s in pipeline.stream()] - ) - assert query_results == pipeline_results - except FailedPrecondition as e: - # if testing against a non-enterprise db, skip this check - if ENTERPRISE_MODE_ERROR not in e.message: - raise e + assert query_results == pipeline_results + except FailedPrecondition as e: + # if testing against a non-enterprise db, skip this check + if ENTERPRISE_MODE_ERROR not in e.message: + raise e + + return _verifier @pytest.fixture(scope="module") @@ -1274,7 +1282,7 @@ async def async_query(collection): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_legacy_where(query_docs, database): +async def test_query_stream_legacy_where(query_docs, database, verify_pipeline): """Assert the legacy code still works and returns value, and shows UserWarning""" collection, stored, allowed_vals = query_docs with pytest.warns( @@ -1291,7 +1299,7 @@ async def test_query_stream_legacy_where(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_simple_field_eq_op(query_docs, database): +async def test_query_stream_w_simple_field_eq_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} @@ -1303,7 +1311,7 @@ async def test_query_stream_w_simple_field_eq_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_simple_field_array_contains_op(query_docs, database): +async def test_query_stream_w_simple_field_array_contains_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} @@ -1315,7 +1323,7 @@ async def test_query_stream_w_simple_field_array_contains_op(query_docs, databas @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_simple_field_in_op(query_docs, database): +async def test_query_stream_w_simple_field_in_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("a", "in", [1, num_vals + 100])) @@ -1328,7 +1336,7 @@ async def test_query_stream_w_simple_field_in_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): +async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where( @@ -1343,7 +1351,7 @@ async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, dat @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_order_by(query_docs, database): +async def test_query_stream_w_order_by(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) values = [(snapshot.id, snapshot.to_dict()) async for snapshot in query.stream()] @@ -1358,7 +1366,7 @@ async def test_query_stream_w_order_by(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_field_path(query_docs, database): +async def test_query_stream_w_field_path(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} @@ -1397,7 +1405,7 @@ async def test_query_stream_w_start_end_cursor(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_wo_results(query_docs, database): +async def test_query_stream_wo_results(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "==", num_vals + 100)) @@ -1424,7 +1432,7 @@ async def test_query_stream_w_projection(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_multiple_filters(query_docs, database): +async def test_query_stream_w_multiple_filters(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( "stats.product", "<", 10 @@ -1445,7 +1453,7 @@ async def test_query_stream_w_multiple_filters(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_offset(query_docs, database): +async def test_query_stream_w_offset(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) offset = 3 @@ -1466,7 +1474,7 @@ async def test_query_stream_w_offset(query_docs, database): ) @pytest.mark.parametrize("method", ["stream", "get"]) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): +async def test_query_stream_or_get_w_no_explain_options(query_docs, database, method, verify_pipeline): from google.cloud.firestore_v1.query_profile import QueryExplainError collection, _, allowed_vals = query_docs @@ -1818,7 +1826,7 @@ async def test_query_with_order_dot_key(client, cleanup, database): @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) -async def test_query_unary(client, cleanup, database): +async def test_query_unary(client, cleanup, database, verify_pipeline): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) field_name = "foo" @@ -1875,7 +1883,7 @@ async def test_query_unary(client, cleanup, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_collection_group_queries(client, cleanup, database): +async def test_collection_group_queries(client, cleanup, database, verify_pipeline): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ @@ -1952,7 +1960,7 @@ async def test_collection_group_queries_startat_endat(client, cleanup, database) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_collection_group_queries_filters(client, cleanup, database): +async def test_collection_group_queries_filters(client, cleanup, database, verify_pipeline): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ From b35f6f1199da6d07729aa4d30b5efb10987e7141 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 1 Apr 2026 15:48:46 -0700 Subject: [PATCH 4/4] fixed lint --- .../tests/system/test_system.py | 23 ++++++++++++++----- .../tests/system/test_system_async.py | 22 +++++++++++++----- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/packages/google-cloud-firestore/tests/system/test_system.py b/packages/google-cloud-firestore/tests/system/test_system.py index 66822ef7e585..79db645fb1df 100644 --- a/packages/google-cloud-firestore/tests/system/test_system.py +++ b/packages/google-cloud-firestore/tests/system/test_system.py @@ -83,6 +83,7 @@ def cleanup(): for operation in operations: operation() + @pytest.fixture def verify_pipeline(subtests): """ @@ -98,8 +99,8 @@ def verify_pipeline(subtests): def _verifier(query): from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery - with subtests.test(msg="verify_pipeline"): + with subtests.test(msg="verify_pipeline"): client = query._client if FIRESTORE_EMULATOR: pytest.skip("skip pipeline verification on emulator") @@ -131,7 +132,9 @@ def _clean_results(results): ) else: # other qureies return a simple list of results - query_results = _clean_results([s.to_dict() for s in query.get()]) + query_results = _clean_results( + [s.to_dict() for s in query.get()] + ) except Exception as e: # if we expect the query to fail, capture the exception query_exception = e @@ -142,7 +145,9 @@ def _clean_results(results): pipeline.execute() else: # ensure results match query - pipeline_results = _clean_results([s.data() for s in pipeline.execute()]) + pipeline_results = _clean_results( + [s.data() for s in pipeline.execute()] + ) assert query_results == pipeline_results except FailedPrecondition as e: # if testing against a non-enterprise db, skip this check @@ -1333,7 +1338,9 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database, verify_pipeline @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_simple_field_array_contains_op(query_docs, database, verify_pipeline): +def test_query_stream_w_simple_field_array_contains_op( + query_docs, database, verify_pipeline +): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1394,7 +1401,9 @@ def test_query_stream_w_simple_not_in_op(query_docs, database, verify_pipeline): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database, verify_pipeline): +def test_query_stream_w_simple_field_array_contains_any_op( + query_docs, database, verify_pipeline +): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where( @@ -1532,7 +1541,9 @@ def test_query_stream_w_offset(query_docs, database, verify_pipeline): ) @pytest.mark.parametrize("method", ["stream", "get"]) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_or_get_w_no_explain_options(query_docs, database, method, verify_pipeline): +def test_query_stream_or_get_w_no_explain_options( + query_docs, database, method, verify_pipeline +): from google.cloud.firestore_v1.query_profile import QueryExplainError collection, _, allowed_vals = query_docs diff --git a/packages/google-cloud-firestore/tests/system/test_system_async.py b/packages/google-cloud-firestore/tests/system/test_system_async.py index c81dfb9e21e1..3a7959830425 100644 --- a/packages/google-cloud-firestore/tests/system/test_system_async.py +++ b/packages/google-cloud-firestore/tests/system/test_system_async.py @@ -181,7 +181,6 @@ async def _verifier(query): from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery with subtests.test(msg="verify_pipeline"): - client = query._client if FIRESTORE_EMULATOR: pytest.skip("skip pipeline verification on emulator") @@ -207,7 +206,10 @@ def _clean_results(results): query_results = _clean_results( list( itertools.chain.from_iterable( - [[a._to_dict() for a in s] for s in await query.get()] + [ + [a._to_dict() for a in s] + for s in await query.get() + ] ) ) ) @@ -1311,7 +1313,9 @@ async def test_query_stream_w_simple_field_eq_op(query_docs, database, verify_pi @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_simple_field_array_contains_op(query_docs, database, verify_pipeline): +async def test_query_stream_w_simple_field_array_contains_op( + query_docs, database, verify_pipeline +): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} @@ -1336,7 +1340,9 @@ async def test_query_stream_w_simple_field_in_op(query_docs, database, verify_pi @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database, verify_pipeline): +async def test_query_stream_w_simple_field_array_contains_any_op( + query_docs, database, verify_pipeline +): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where( @@ -1474,7 +1480,9 @@ async def test_query_stream_w_offset(query_docs, database, verify_pipeline): ) @pytest.mark.parametrize("method", ["stream", "get"]) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_or_get_w_no_explain_options(query_docs, database, method, verify_pipeline): +async def test_query_stream_or_get_w_no_explain_options( + query_docs, database, method, verify_pipeline +): from google.cloud.firestore_v1.query_profile import QueryExplainError collection, _, allowed_vals = query_docs @@ -1960,7 +1968,9 @@ async def test_collection_group_queries_startat_endat(client, cleanup, database) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_collection_group_queries_filters(client, cleanup, database, verify_pipeline): +async def test_collection_group_queries_filters( + client, cleanup, database, verify_pipeline +): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [