Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/google-cloud-firestore/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ALL_PYTHON
SYSTEM_TEST_STANDARD_DEPENDENCIES = [
"mock",
"pytest",
"pytest>9.0",
"google-cloud-testutils",
]
SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [
Expand Down
161 changes: 92 additions & 69 deletions packages/google-cloud-firestore/tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,62 +84,77 @@ def cleanup():
operation()


def verify_pipeline(query):
@pytest.fixture
def verify_pipeline(subtests):
"""
This function ensures a pipeline produces the same
results as the query it is derived from
This fixture provides a subtest function which
ensures a pipeline produces the same results as the query it is derived
from

It can be attached to existing query tests to check both
modalities at the same time
"""
from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery

if FIRESTORE_EMULATOR:
pytest.skip("skip pipeline verification on emulator")

def _clean_results(results):
if isinstance(results, dict):
return {k: _clean_results(v) for k, v in results.items()}
elif isinstance(results, list):
return [_clean_results(r) for r in results]
elif isinstance(results, float) and math.isnan(results):
return "__NAN_VALUE__"
else:
return results
Pipelines are only supported on enterprise dbs. Skip other environments
"""

query_exception = None
query_results = None
try:
try:
if isinstance(query, BaseAggregationQuery):
# aggregation queries return a list of lists of aggregation results
query_results = _clean_results(
list(
itertools.chain.from_iterable(
[[a._to_dict() for a in s] for s in query.get()]
def _verifier(query):
from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery

with subtests.test(msg="verify_pipeline"):
client = query._client
if FIRESTORE_EMULATOR:
pytest.skip("skip pipeline verification on emulator")
if client._database != FIRESTORE_ENTERPRISE_DB:
pytest.skip("pipelines only supports enterprise db")

def _clean_results(results):
if isinstance(results, dict):
return {k: _clean_results(v) for k, v in results.items()}
elif isinstance(results, list):
return [_clean_results(r) for r in results]
elif isinstance(results, float) and math.isnan(results):
return "__NAN_VALUE__"
else:
return results

query_exception = None
query_results = None
try:
try:
if isinstance(query, BaseAggregationQuery):
# aggregation queries return a list of lists of aggregation results
query_results = _clean_results(
list(
itertools.chain.from_iterable(
[[a._to_dict() for a in s] for s in query.get()]
)
)
)
else:
# other queries return a simple list of results
query_results = _clean_results(
[s.to_dict() for s in query.get()]
)
except Exception as e:
# if we expect the query to fail, capture the exception
query_exception = e
pipeline = client.pipeline().create_from(query)
if query_exception:
# ensure that the pipeline uses same error as query
with pytest.raises(query_exception.__class__):
pipeline.execute()
else:
# ensure results match query
pipeline_results = _clean_results(
[s.data() for s in pipeline.execute()]
)
)
else:
# other qureies return a simple list of results
query_results = _clean_results([s.to_dict() for s in query.get()])
except Exception as e:
# if we expect the query to fail, capture the exception
query_exception = e
client = query._client
pipeline = client.pipeline().create_from(query)
if query_exception:
# ensure that the pipeline uses same error as query
with pytest.raises(query_exception.__class__):
pipeline.execute()
else:
# ensure results match query
pipeline_results = _clean_results([s.data() for s in pipeline.execute()])
assert query_results == pipeline_results
except FailedPrecondition as e:
# if testing against a non-enterprise db, skip this check
if ENTERPRISE_MODE_ERROR not in e.message:
raise e
assert query_results == pipeline_results
except FailedPrecondition as e:
# if testing against a non-enterprise db, skip this check
if ENTERPRISE_MODE_ERROR not in e.message:
raise e

return _verifier


@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
Expand Down Expand Up @@ -1294,7 +1309,7 @@ def query(collection):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_legacy_where(query_docs, database):
def test_query_stream_legacy_where(query_docs, database, verify_pipeline):
"""Assert the legacy code still works and returns value"""
collection, stored, allowed_vals = query_docs
with pytest.warns(
Expand All @@ -1311,7 +1326,7 @@ def test_query_stream_legacy_where(query_docs, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_w_simple_field_eq_op(query_docs, database):
def test_query_stream_w_simple_field_eq_op(query_docs, database, verify_pipeline):
collection, stored, allowed_vals = query_docs
query = collection.where(filter=FieldFilter("a", "==", 1))
values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
Expand All @@ -1323,7 +1338,9 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_w_simple_field_array_contains_op(query_docs, database):
def test_query_stream_w_simple_field_array_contains_op(
query_docs, database, verify_pipeline
):
collection, stored, allowed_vals = query_docs
query = collection.where(filter=FieldFilter("c", "array_contains", 1))
values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
Expand All @@ -1335,7 +1352,7 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_w_simple_field_in_op(query_docs, database):
def test_query_stream_w_simple_field_in_op(query_docs, database, verify_pipeline):
collection, stored, allowed_vals = query_docs
num_vals = len(allowed_vals)
query = collection.where(filter=FieldFilter("a", "in", [1, num_vals + 100]))
Expand All @@ -1348,7 +1365,7 @@ def test_query_stream_w_simple_field_in_op(query_docs, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_w_not_eq_op(query_docs, database):
def test_query_stream_w_not_eq_op(query_docs, database, verify_pipeline):
collection, stored, allowed_vals = query_docs
query = collection.where(filter=FieldFilter("stats.sum", "!=", 4))
values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
Expand All @@ -1371,7 +1388,7 @@ def test_query_stream_w_not_eq_op(query_docs, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_w_simple_not_in_op(query_docs, database):
def test_query_stream_w_simple_not_in_op(query_docs, database, verify_pipeline):
collection, stored, allowed_vals = query_docs
num_vals = len(allowed_vals)
query = collection.where(
Expand All @@ -1384,7 +1401,9 @@ def test_query_stream_w_simple_not_in_op(query_docs, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database):
def test_query_stream_w_simple_field_array_contains_any_op(
query_docs, database, verify_pipeline
):
collection, stored, allowed_vals = query_docs
num_vals = len(allowed_vals)
query = collection.where(
Expand All @@ -1399,7 +1418,7 @@ def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database)


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_w_order_by(query_docs, database):
def test_query_stream_w_order_by(query_docs, database, verify_pipeline):
collection, stored, allowed_vals = query_docs
query = collection.order_by("b", direction=firestore.Query.DESCENDING)
values = [(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()]
Expand All @@ -1414,7 +1433,7 @@ def test_query_stream_w_order_by(query_docs, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_w_field_path(query_docs, database):
def test_query_stream_w_field_path(query_docs, database, verify_pipeline):
collection, stored, allowed_vals = query_docs
query = collection.where(filter=FieldFilter("stats.sum", ">", 4))
values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
Expand Down Expand Up @@ -1453,7 +1472,7 @@ def test_query_stream_w_start_end_cursor(query_docs, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_wo_results(query_docs, database):
def test_query_stream_wo_results(query_docs, database, verify_pipeline):
collection, stored, allowed_vals = query_docs
num_vals = len(allowed_vals)
query = collection.where(filter=FieldFilter("b", "==", num_vals + 100))
Expand All @@ -1480,7 +1499,7 @@ def test_query_stream_w_projection(query_docs, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_w_multiple_filters(query_docs, database):
def test_query_stream_w_multiple_filters(query_docs, database, verify_pipeline):
collection, stored, allowed_vals = query_docs
query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where(
filter=FieldFilter("stats.product", "<", 10)
Expand All @@ -1501,7 +1520,7 @@ def test_query_stream_w_multiple_filters(query_docs, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_w_offset(query_docs, database):
def test_query_stream_w_offset(query_docs, database, verify_pipeline):
collection, stored, allowed_vals = query_docs
num_vals = len(allowed_vals)
offset = 3
Expand All @@ -1522,7 +1541,9 @@ def test_query_stream_w_offset(query_docs, database):
)
@pytest.mark.parametrize("method", ["stream", "get"])
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_or_get_w_no_explain_options(query_docs, database, method):
def test_query_stream_or_get_w_no_explain_options(
query_docs, database, method, verify_pipeline
):
from google.cloud.firestore_v1.query_profile import QueryExplainError

collection, _, allowed_vals = query_docs
Expand Down Expand Up @@ -1886,7 +1907,7 @@ def test_query_with_order_dot_key(client, cleanup, database):


@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
def test_query_unary(client, cleanup, database):
def test_query_unary(client, cleanup, database, verify_pipeline):
collection_name = "unary" + UNIQUE_RESOURCE_ID
collection = client.collection(collection_name)
field_name = "foo"
Expand Down Expand Up @@ -1943,7 +1964,7 @@ def test_query_unary(client, cleanup, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_collection_group_queries(client, cleanup, database):
def test_collection_group_queries(client, cleanup, database, verify_pipeline):
collection_group = "b" + UNIQUE_RESOURCE_ID

doc_paths = [
Expand Down Expand Up @@ -2020,7 +2041,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_collection_group_queries_filters(client, cleanup, database):
def test_collection_group_queries_filters(client, cleanup, database, verify_pipeline):
collection_group = "b" + UNIQUE_RESOURCE_ID

doc_paths = [
Expand Down Expand Up @@ -2811,7 +2832,7 @@ def on_snapshot(docs, changes, read_time):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_repro_429(client, cleanup, database):
def test_repro_429(client, cleanup, database, verify_pipeline):
# See: https://github.com/googleapis/python-firestore/issues/429
now = datetime.datetime.now(tz=datetime.timezone.utc)
collection = client.collection("repro-429" + UNIQUE_RESOURCE_ID)
Expand Down Expand Up @@ -3406,7 +3427,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_false(


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_with_and_composite_filter(collection, database):
def test_query_with_and_composite_filter(collection, database, verify_pipeline):
and_filter = And(
filters=[
FieldFilter("stats.product", ">", 5),
Expand All @@ -3422,7 +3443,7 @@ def test_query_with_and_composite_filter(collection, database):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_with_or_composite_filter(collection, database):
def test_query_with_or_composite_filter(collection, database, verify_pipeline):
or_filter = Or(
filters=[
FieldFilter("stats.product", ">", 5),
Expand Down Expand Up @@ -3456,6 +3477,7 @@ def test_aggregation_queries_with_read_time(
database,
aggregation_type,
expected_value,
verify_pipeline,
):
"""
Ensure that all aggregation queries work when read_time is passed into
Expand Down Expand Up @@ -3494,7 +3516,7 @@ def test_aggregation_queries_with_read_time(


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_with_complex_composite_filter(collection, database):
def test_query_with_complex_composite_filter(collection, database, verify_pipeline):
field_filter = FieldFilter("b", "==", 0)
or_filter = Or(
filters=[FieldFilter("stats.sum", "==", 0), FieldFilter("stats.sum", "==", 4)]
Expand Down Expand Up @@ -3552,6 +3574,7 @@ def test_aggregation_query_in_transaction(
aggregation_type,
aggregation_args,
expected,
verify_pipeline,
):
"""
Test creating an aggregation query inside a transaction
Expand Down Expand Up @@ -3593,7 +3616,7 @@ def in_transaction(transaction):


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_or_query_in_transaction(client, cleanup, database):
def test_or_query_in_transaction(client, cleanup, database, verify_pipeline):
"""
Test running or query inside a transaction. Should pass transaction id along with request
"""
Expand Down
Loading
Loading