diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py index fac7f8bc4bce..1e1b7477bd92 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py @@ -610,3 +610,59 @@ def distinct(self, *fields: str | Selectable) -> "_BasePipeline": A new Pipeline object with this stage appended to the stage list """ return self._append(stages.Distinct(*fields)) + + def delete(self) -> "_BasePipeline": + """ + Deletes the documents from the current pipeline stage. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> pipeline = client.pipeline().collection("logs") + >>> # Delete all documents in the "logs" collection where "status" is "archived" + >>> pipeline = pipeline.where(Field.of("status").equal("archived")).delete() + >>> pipeline.execute() + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Delete()) + + def update(self, *transformed_fields: "Selectable") -> "_BasePipeline": + """ + Performs an update operation using documents from previous stages. + + If called without `transformed_fields`, this method updates the documents in + place based on the data flowing through the pipeline. + + To update specific fields with new values, provide `Selectable` expressions that define the + transformations to apply. + + Example 1: Update a collection's schema by adding a new field and removing an old one. + >>> from google.cloud.firestore_v1.pipeline_expressions import Constant + >>> pipeline = client.pipeline().collection("books") + >>> pipeline = pipeline.add_fields(Constant.of("Fiction").as_("genre")) + >>> pipeline = pipeline.remove_fields("old_genre").update() + >>> pipeline.execute() + + Example 2: Update documents in place with data from literals. + >>> pipeline = client.pipeline().literals( + ... {"__name__": client.collection("books").document("book1"), "status": "Updated"} + ... ).update() + >>> pipeline.execute() + + Example 3: Update documents from previous stages with specified transformations. + >>> from google.cloud.firestore_v1.pipeline_expressions import Field, Constant + >>> pipeline = client.pipeline().collection("books") + >>> # Update the "status" field to "Discounted" for all books where price > 50 + >>> pipeline = pipeline.where(Field.of("price").greater_than(50)) + >>> pipeline = pipeline.update(Constant.of("Discounted").as_("status")) + >>> pipeline.execute() + + Args: + *transformed_fields: Optional. The transformations to apply. If not provided, + the update is performed in place based on the data flowing through the pipeline. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Update(*transformed_fields)) diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py index cac9c70d4b99..9de782f3cbfa 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py @@ -494,3 +494,24 @@ def __init__(self, condition: BooleanExpression): def _pb_args(self): return [self.condition._to_pb()] + + +class Delete(Stage): + """Deletes documents matching the pipeline criteria.""" + + def __init__(self): + super().__init__("delete") + + def _pb_args(self) -> list[Value]: + return [] + + +class Update(Stage): + """Updates documents with transformed fields.""" + + def __init__(self, *transformed_fields: Selectable): + super().__init__("update") + self.transformed_fields = list(transformed_fields) + + def _pb_args(self) -> list[Value]: + return [Selectable._to_value(self.transformed_fields)] diff --git a/packages/google-cloud-firestore/tests/system/pipeline_e2e/dml.yaml b/packages/google-cloud-firestore/tests/system/pipeline_e2e/dml.yaml new file mode 100644 index 000000000000..578ddef20492 --- /dev/null +++ b/packages/google-cloud-firestore/tests/system/pipeline_e2e/dml.yaml @@ -0,0 +1,55 @@ +data: + dml_delete_coll: + doc1: { score: 10 } + doc2: { score: 60 } + dml_update_coll: + doc1: { status: "pending", score: 50 } + +tests: + - description: "Basic DML delete" + pipeline: + - Collection: dml_delete_coll + - Where: + FunctionExpression.less_than: + - Field: score + - Constant: 50 + - Delete: + assert_end_state: + dml_delete_coll/doc1: null + dml_delete_coll/doc2: { score: 60 } + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /dml_delete_coll + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: score + - integerValue: '50' + name: less_than + name: where + - name: delete + + - description: "Basic DML update" + pipeline: + - Collection: dml_update_coll + - Update: + - AliasedExpression: + - Constant: "active" + - "status" + assert_end_state: + dml_update_coll/doc1: { status: "active", score: 50 } + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /dml_update_coll + name: collection + - args: + - mapValue: + fields: + status: + stringValue: active + name: update diff --git a/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py b/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py index afff43ac6950..289ad165f7db 100644 --- a/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py +++ b/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py @@ -119,6 +119,7 @@ def test_pipeline_expected_errors(test_dict, client): if "assert_results" in t or "assert_count" in t or "assert_results_approximate" in t + or "assert_end_state" in t ], ids=id_format, ) @@ -131,6 +132,7 @@ def test_pipeline_results(test_dict, client): test_dict.get("assert_results_approximate", None) ) expected_count = test_dict.get("assert_count", None) + expected_end_state = _parse_yaml_types(test_dict.get("assert_end_state", {})) pipeline = parse_pipeline(client, test_dict["pipeline"]) # check if server responds as expected got_results = [snapshot.data() for snapshot in pipeline.stream()] @@ -146,6 +148,19 @@ def test_pipeline_results(test_dict, client): ) if expected_count is not None: assert len(got_results) == expected_count + if expected_end_state: + for doc_path, expected_content in expected_end_state.items(): + doc_ref = client.document(doc_path) + snapshot = doc_ref.get() + if expected_content is None: + assert not snapshot.exists, ( + f"Expected {doc_path} to be absent, but it exists" + ) + else: + assert snapshot.exists, ( + f"Expected {doc_path} to exist, but it was absent" + ) + assert snapshot.to_dict() == expected_content @pytest.mark.parametrize( @@ -176,6 +191,7 @@ async def test_pipeline_expected_errors_async(test_dict, async_client): if "assert_results" in t or "assert_count" in t or "assert_results_approximate" in t + or "assert_end_state" in t ], ids=id_format, ) @@ -189,6 +205,7 @@ async def test_pipeline_results_async(test_dict, async_client): test_dict.get("assert_results_approximate", None) ) expected_count = test_dict.get("assert_count", None) + expected_end_state = _parse_yaml_types(test_dict.get("assert_end_state", {})) pipeline = parse_pipeline(async_client, test_dict["pipeline"]) # check if server responds as expected got_results = [snapshot.data() async for snapshot in pipeline.stream()] @@ -204,6 +221,19 @@ async def test_pipeline_results_async(test_dict, async_client): ) if expected_count is not None: assert len(got_results) == expected_count + if expected_end_state: + for doc_path, expected_content in expected_end_state.items(): + doc_ref = async_client.document(doc_path) + snapshot = await doc_ref.get() + if expected_content is None: + assert not snapshot.exists, ( + f"Expected {doc_path} to be absent, but it exists" + ) + else: + assert snapshot.exists, ( + f"Expected {doc_path} to exist, but it was absent" + ) + assert snapshot.to_dict() == expected_content ################################################################################# @@ -223,7 +253,12 @@ def parse_pipeline(client, pipeline: list[dict[str, Any], str]): # find arguments if given if isinstance(stage, dict): stage_yaml_args = stage[stage_name] - stage_obj = _apply_yaml_args_to_callable(stage_cls, client, stage_yaml_args) + if stage_yaml_args is None: + stage_obj = stage_cls() + else: + stage_obj = _apply_yaml_args_to_callable( + stage_cls, client, stage_yaml_args + ) else: # yaml has no arguments stage_obj = stage_cls() @@ -279,18 +314,19 @@ def _apply_yaml_args_to_callable(callable_obj, client, yaml_args): Helper to instantiate a class with yaml arguments. The arguments will be applied as positional or keyword arguments, based on type """ - if isinstance(yaml_args, dict): - return callable_obj(**_parse_expressions(client, yaml_args)) + parsed = _parse_expressions(client, yaml_args) + if isinstance(yaml_args, dict) and isinstance(parsed, dict): + return callable_obj(**parsed) elif isinstance(yaml_args, list) and not ( callable_obj == expr.Constant or callable_obj == Vector or callable_obj == expr.Array ): # yaml has an array of arguments. Treat as args - return callable_obj(*_parse_expressions(client, yaml_args)) + return callable_obj(*parsed) else: # yaml has a single argument - return callable_obj(_parse_expressions(client, yaml_args)) + return callable_obj(parsed) def _is_expr_string(yaml_str): diff --git a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py index 65685e6e33d6..b9ab603b713b 100644 --- a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py +++ b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py @@ -960,3 +960,39 @@ def test_to_pb(self): assert got_fn.args[0].field_reference_value == "city" assert got_fn.args[1].string_value == "SF" assert len(result.options) == 0 + + +class TestDelete: + def _make_one(self): + return stages.Delete() + + def test_to_pb(self): + instance = self._make_one() + result = instance._to_pb() + assert result.name == "delete" + assert len(result.args) == 0 + assert len(result.options) == 0 + + +class TestUpdate: + def _make_one(self, *args): + return stages.Update(*args) + + def test_to_pb_empty(self): + instance = self._make_one() + result = instance._to_pb() + assert result.name == "update" + assert len(result.args) == 1 + assert result.args[0].map_value.fields == {} + assert len(result.options) == 0 + + def test_to_pb_with_fields(self): + instance = self._make_one( + Field.of("score").add(10).as_("score"), Constant.of("active").as_("status") + ) + result = instance._to_pb() + assert result.name == "update" + assert len(result.args) == 1 + assert "score" in result.args[0].map_value.fields + assert "status" in result.args[0].map_value.fields + assert len(result.options) == 0