Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,59 @@ def distinct(self, *fields: str | Selectable) -> "_BasePipeline":
A new Pipeline object with this stage appended to the stage list
"""
return self._append(stages.Distinct(*fields))

def delete(self) -> "_BasePipeline":
    """Append a stage that deletes every document produced by the pipeline.

    Example:
        >>> from google.cloud.firestore_v1.pipeline_expressions import Field
        >>> pipeline = client.pipeline().collection("logs")
        >>> # Remove each "logs" document whose "status" is "archived"
        >>> pipeline = pipeline.where(Field.of("status").equal("archived")).delete()
        >>> pipeline.execute()

    Returns:
        A new Pipeline object with this stage appended to the stage list
    """
    delete_stage = stages.Delete()
    return self._append(delete_stage)

def update(self, *transformed_fields: "Selectable") -> "_BasePipeline":
    """Append a stage that writes documents from previous stages back to Firestore.

    When invoked with no `transformed_fields`, documents are updated in place
    using the data currently flowing through the pipeline.

    To rewrite particular fields instead, pass `Selectable` expressions
    describing the transformations to apply.

    Example 1: Update a collection's schema by adding a new field and removing an old one.
        >>> from google.cloud.firestore_v1.pipeline_expressions import Constant
        >>> pipeline = client.pipeline().collection("books")
        >>> pipeline = pipeline.add_fields(Constant.of("Fiction").as_("genre"))
        >>> pipeline = pipeline.remove_fields("old_genre").update()
        >>> pipeline.execute()

    Example 2: Update documents in place with data from literals.
        >>> pipeline = client.pipeline().literals(
        ...     {"__name__": client.collection("books").document("book1"), "status": "Updated"}
        ... ).update()
        >>> pipeline.execute()

    Example 3: Update documents from previous stages with specified transformations.
        >>> from google.cloud.firestore_v1.pipeline_expressions import Field, Constant
        >>> pipeline = client.pipeline().collection("books")
        >>> # Update the "status" field to "Discounted" for all books where price > 50
        >>> pipeline = pipeline.where(Field.of("price").greater_than(50))
        >>> pipeline = pipeline.update(Constant.of("Discounted").as_("status"))
        >>> pipeline.execute()

    Args:
        *transformed_fields: Optional. The transformations to apply. If not provided,
            the update is performed in place based on the data flowing through the pipeline.

    Returns:
        A new Pipeline object with this stage appended to the stage list
    """
    update_stage = stages.Update(*transformed_fields)
    return self._append(update_stage)
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,24 @@ def __init__(self, condition: BooleanExpression):

def _pb_args(self):
return [self.condition._to_pb()]


class Delete(Stage):
    """Pipeline stage that removes the documents flowing through it."""

    def __init__(self):
        # The wire name of this stage is "delete".
        super().__init__("delete")

    def _pb_args(self) -> list[Value]:
        # A delete stage carries no arguments.
        return []


class Update(Stage):
    """Pipeline stage that writes transformed fields back to documents."""

    def __init__(self, *transformed_fields: Selectable):
        # The wire name of this stage is "update".
        super().__init__("update")
        self.transformed_fields = [field for field in transformed_fields]

    def _pb_args(self) -> list[Value]:
        # All transformations are packed into one map-valued argument.
        packed = Selectable._to_value(self.transformed_fields)
        return [packed]
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Seed documents written before the DML pipeline tests execute.
data:
  dml_delete_coll:
    doc1: { score: 10 }
    doc2: { score: 60 }
  dml_update_coll:
    doc1: { status: "pending", score: 50 }

tests:
  - description: "Basic DML delete"
    pipeline:
      - Collection: dml_delete_coll
      - Where:
          FunctionExpression.less_than:
            - Field: score
            - Constant: 50
      # "Delete:" with no value → the stage is constructed with no arguments.
      - Delete:
    # doc1 (score 10 < 50) must be removed; doc2 must survive unchanged.
    assert_end_state:
      dml_delete_coll/doc1: null
      dml_delete_coll/doc2: { score: 60 }
    assert_proto:
      pipeline:
        stages:
          - args:
              - referenceValue: /dml_delete_coll
            name: collection
          - args:
              - functionValue:
                  args:
                    - fieldReferenceValue: score
                    - integerValue: '50'
                  name: less_than
            name: where
          - name: delete

  - description: "Basic DML update"
    pipeline:
      - Collection: dml_update_coll
      - Update:
          - AliasedExpression:
              - Constant: "active"
              - "status"
    # The "status" field is rewritten; the untouched "score" field is kept.
    assert_end_state:
      dml_update_coll/doc1: { status: "active", score: 50 }
    assert_proto:
      pipeline:
        stages:
          - args:
              - referenceValue: /dml_update_coll
            name: collection
          - args:
              - mapValue:
                  fields:
                    status:
                      stringValue: active
            name: update
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def test_pipeline_expected_errors(test_dict, client):
if "assert_results" in t
or "assert_count" in t
or "assert_results_approximate" in t
or "assert_end_state" in t
],
ids=id_format,
)
Expand All @@ -131,6 +132,7 @@ def test_pipeline_results(test_dict, client):
test_dict.get("assert_results_approximate", None)
)
expected_count = test_dict.get("assert_count", None)
expected_end_state = _parse_yaml_types(test_dict.get("assert_end_state", {}))
pipeline = parse_pipeline(client, test_dict["pipeline"])
# check if server responds as expected
got_results = [snapshot.data() for snapshot in pipeline.stream()]
Expand All @@ -146,6 +148,19 @@ def test_pipeline_results(test_dict, client):
)
if expected_count is not None:
assert len(got_results) == expected_count
if expected_end_state:
for doc_path, expected_content in expected_end_state.items():
doc_ref = client.document(doc_path)
snapshot = doc_ref.get()
if expected_content is None:
assert not snapshot.exists, (
f"Expected {doc_path} to be absent, but it exists"
)
else:
assert snapshot.exists, (
f"Expected {doc_path} to exist, but it was absent"
)
assert snapshot.to_dict() == expected_content


@pytest.mark.parametrize(
Expand Down Expand Up @@ -176,6 +191,7 @@ async def test_pipeline_expected_errors_async(test_dict, async_client):
if "assert_results" in t
or "assert_count" in t
or "assert_results_approximate" in t
or "assert_end_state" in t
],
ids=id_format,
)
Expand All @@ -189,6 +205,7 @@ async def test_pipeline_results_async(test_dict, async_client):
test_dict.get("assert_results_approximate", None)
)
expected_count = test_dict.get("assert_count", None)
expected_end_state = _parse_yaml_types(test_dict.get("assert_end_state", {}))
pipeline = parse_pipeline(async_client, test_dict["pipeline"])
# check if server responds as expected
got_results = [snapshot.data() async for snapshot in pipeline.stream()]
Expand All @@ -204,6 +221,19 @@ async def test_pipeline_results_async(test_dict, async_client):
)
if expected_count is not None:
assert len(got_results) == expected_count
if expected_end_state:
for doc_path, expected_content in expected_end_state.items():
doc_ref = async_client.document(doc_path)
snapshot = await doc_ref.get()
if expected_content is None:
assert not snapshot.exists, (
f"Expected {doc_path} to be absent, but it exists"
)
else:
assert snapshot.exists, (
f"Expected {doc_path} to exist, but it was absent"
)
assert snapshot.to_dict() == expected_content


#################################################################################
Expand All @@ -223,7 +253,12 @@ def parse_pipeline(client, pipeline: list[dict[str, Any], str]):
# find arguments if given
if isinstance(stage, dict):
stage_yaml_args = stage[stage_name]
stage_obj = _apply_yaml_args_to_callable(stage_cls, client, stage_yaml_args)
if stage_yaml_args is None:
stage_obj = stage_cls()
else:
stage_obj = _apply_yaml_args_to_callable(
stage_cls, client, stage_yaml_args
)
else:
# yaml has no arguments
stage_obj = stage_cls()
Expand Down Expand Up @@ -279,18 +314,19 @@ def _apply_yaml_args_to_callable(callable_obj, client, yaml_args):
Helper to instantiate a class with yaml arguments. The arguments will be applied
as positional or keyword arguments, based on type
"""
if isinstance(yaml_args, dict):
return callable_obj(**_parse_expressions(client, yaml_args))
parsed = _parse_expressions(client, yaml_args)
if isinstance(yaml_args, dict) and isinstance(parsed, dict):
return callable_obj(**parsed)
elif isinstance(yaml_args, list) and not (
callable_obj == expr.Constant
or callable_obj == Vector
or callable_obj == expr.Array
):
# yaml has an array of arguments. Treat as args
return callable_obj(*_parse_expressions(client, yaml_args))
return callable_obj(*parsed)
else:
# yaml has a single argument
return callable_obj(_parse_expressions(client, yaml_args))
return callable_obj(parsed)


def _is_expr_string(yaml_str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -960,3 +960,39 @@ def test_to_pb(self):
assert got_fn.args[0].field_reference_value == "city"
assert got_fn.args[1].string_value == "SF"
assert len(result.options) == 0


class TestDelete:
    """Unit tests for the Delete pipeline stage."""

    def _make_one(self):
        return stages.Delete()

    def test_to_pb(self):
        # The serialized stage is named "delete" and carries nothing else.
        pb = self._make_one()._to_pb()
        assert pb.name == "delete"
        assert len(pb.args) == 0
        assert len(pb.options) == 0


class TestUpdate:
    """Unit tests for the Update pipeline stage."""

    def _make_one(self, *args):
        return stages.Update(*args)

    def test_to_pb_empty(self):
        # With no transformations the single map argument is empty.
        pb = self._make_one()._to_pb()
        assert pb.name == "update"
        assert len(pb.args) == 1
        assert pb.args[0].map_value.fields == {}
        assert len(pb.options) == 0

    def test_to_pb_with_fields(self):
        # Each aliased expression becomes one entry in the map argument.
        instance = self._make_one(
            Field.of("score").add(10).as_("score"),
            Constant.of("active").as_("status"),
        )
        pb = instance._to_pb()
        assert pb.name == "update"
        assert len(pb.args) == 1
        fields = pb.args[0].map_value.fields
        assert "score" in fields
        assert "status" in fields
        assert len(pb.options) == 0
Loading