diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/_helpers.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/_helpers.py index 04cdd05f0c45..23035a5f84aa 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/_helpers.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/_helpers.py @@ -1133,7 +1133,7 @@ def modify_write(self, write, *unused_args, **unused_kwargs) -> None: def make_retry_timeout_kwargs( retry: retries.Retry | retries.AsyncRetry | object | None, timeout: float | None -) -> dict: +) -> dict[str, Any]: """Helper fo API methods which take optional 'retry' / 'timeout' args.""" kwargs = {} diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/async_document.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/async_document.py index 85d4dc3bae19..3c3b311f1e27 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/async_document.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/async_document.py @@ -18,7 +18,7 @@ import datetime import logging -from typing import AsyncGenerator, Iterable +from typing import TYPE_CHECKING, AsyncGenerator, Iterable from google.api_core import gapic_v1 from google.api_core import retry_async as retries @@ -33,10 +33,16 @@ ) from google.cloud.firestore_v1.types import write +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.async_client import AsyncClient +else: + AsyncClient = None + + logger = logging.getLogger(__name__) -class AsyncDocumentReference(BaseDocumentReference): +class AsyncDocumentReference(BaseDocumentReference[AsyncClient]): """A reference to a document in a Firestore database. The document may already exist or can be created by this class. @@ -317,6 +323,8 @@ async def delete( """ request, kwargs = self._prep_delete(option, retry, timeout) + if self._client is None: + raise ValueError("A deletion requires a `client`.") commit_response = await self._client._firestore_api.commit( request=request, metadata=self._client._rpc_metadata, @@ -374,6 +382,8 @@ async def get( field_paths, transaction, retry, timeout, read_time ) + if self._client is None: + raise ValueError("A get requires a `client`.") response_iter = await self._client._firestore_api.batch_get_documents( request=request, metadata=self._client._rpc_metadata, @@ -433,6 +443,8 @@ async def collections( """ request, kwargs = self._prep_collections(page_size, retry, timeout, read_time) + if self._client is None: + raise ValueError("A collection reference requires a `client`.") iterator = await self._client._firestore_api.list_collection_ids( request=request, metadata=self._client._rpc_metadata, diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/base_document.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/base_document.py index 92d8daa21fd6..84e5dfb27860 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/base_document.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/base_document.py @@ -22,10 +22,11 @@ Any, Awaitable, Dict, + Generic, Iterable, Optional, - Tuple, - Union, + TypeVar, + overload, ) from google.api_core import retry as retries @@ -34,20 +35,36 @@ from google.cloud.firestore_v1 import field_path as field_path_module from google.cloud.firestore_v1.types import common -# Types needed only for Type Hints if TYPE_CHECKING: # pragma: NO COVER import datetime + from collections.abc import AsyncIterable, Callable - from google.cloud.firestore_v1.types import Document, firestore, write + from google.protobuf.timestamp_pb2 import Timestamp + from google.cloud.firestore_v1.async_batch import AsyncWriteBatch + from google.cloud.firestore_v1.async_client import AsyncClient + from google.cloud.firestore_v1.batch import WriteBatch + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.types import ( + Document, + firestore, + write, + ) + from google.cloud.firestore_v1.watch import Watch -class BaseDocumentReference(object): + Self = TypeVar("Self", bound="BaseDocumentReference") + C = TypeVar("C", AsyncClient, Client) +else: + C = TypeVar("C") + + +class BaseDocumentReference(Generic[C]): """A reference to a document in a Firestore database. The document may already exist or can be created by this class. Args: - path (Tuple[str, ...]): The components in the document path. + path (tuple[str, ...]): The components in the document path. This is a series of strings representing each collection and sub-collection ID, as well as the document IDs for any documents that contain a sub-collection (as well as the base document). @@ -68,16 +85,16 @@ class BaseDocumentReference(object): _document_path_internal = None - def __init__(self, *path, **kwargs) -> None: + def __init__(self, *path: str, **kwargs: C | None) -> None: _helpers.verify_path(path, is_collection=False) self._path = path - self._client = kwargs.pop("client", None) + self._client: C | None = kwargs.pop("client", None) if kwargs: raise TypeError( "Received unexpected arguments", kwargs, "Only `client` is supported" ) - def __copy__(self): + def __copy__(self: Self) -> Self: """Shallow copy the instance. We leave the client "as-is" but tuple-unpack the path. @@ -89,7 +106,7 @@ def __copy__(self): result._document_path_internal = self._document_path_internal return result - def __deepcopy__(self, unused_memo): + def __deepcopy__(self: Self, unused_memo: object) -> Self: """Deep copy the instance. This isn't a true deep copy, wee leave the client "as-is" but @@ -100,14 +117,14 @@ def __deepcopy__(self, unused_memo): """ return self.__copy__() - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """Equality check against another instance. Args: - other (Any): A value to compare against. + other (object): A value to compare against. Returns: - Union[bool, NotImplementedType]: Indicating if the values are + bool | NotImplementedType: Indicating if the values are equal. """ if isinstance(other, self.__class__): @@ -115,17 +132,17 @@ def __eq__(self, other): else: return NotImplemented - def __hash__(self): - return hash(self._path) + hash(self._client) + def __hash__(self) -> int: + return hash((self._path, self._client)) - def __ne__(self, other): + def __ne__(self, other) -> bool: """Inequality check against another instance. Args: - other (Any): A value to compare against. + other (object): A value to compare against. Returns: - Union[bool, NotImplementedType]: Indicating if the values are + bool | NotImplementedType: Indicating if the values are not equal. """ if isinstance(other, self.__class__): @@ -134,7 +151,7 @@ def __ne__(self, other): return NotImplemented @property - def path(self): + def path(self) -> str: """Database-relative for this document. Returns: @@ -143,7 +160,7 @@ def path(self): return "/".join(self._path) @property - def _document_path(self): + def _document_path(self) -> str: """Create and cache the full path for this document. Of the form: @@ -165,7 +182,7 @@ def _document_path(self): return self._document_path_internal @property - def id(self): + def id(self) -> str: """The document identifier (within its collection). Returns: @@ -174,7 +191,7 @@ def id(self): return self._path[-1] @property - def parent(self): + def parent(self: Self) -> Self: """Collection that owns the current document. Returns: @@ -182,9 +199,11 @@ def parent(self): The parent collection. """ parent_path = self._path[:-1] + if self._client is None: + raise ValueError("A collection reference requires a `client`.") return self._client.collection(*parent_path) - def collection(self, collection_id: str): + def collection(self: Self, collection_id: str) -> Self: """Create a sub-collection underneath the current document. Args: @@ -196,14 +215,36 @@ def collection(self, collection_id: str): The child collection. """ child_path = self._path + (collection_id,) + if self._client is None: + raise ValueError("A collection reference requires a `client`.") return self._client.collection(*child_path) + @overload + def _prep_create( + self: "BaseDocumentReference[AsyncClient]", + document_data: dict, + retry: retries.AsyncRetry | None | object = None, + timeout: float | None = None, + ) -> tuple[AsyncWriteBatch, dict[str, Any]]: + pass + + @overload + def _prep_create( + self: "BaseDocumentReference[Client]", + document_data: dict, + retry: retries.Retry | None | object = None, + timeout: float | None = None, + ) -> tuple[WriteBatch, dict[str, Any]]: + pass + def _prep_create( self, document_data: dict, retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, - ) -> Tuple[Any, dict]: + ) -> tuple[AsyncWriteBatch | WriteBatch, dict[str, Any]]: + if self._client is None: + raise ValueError("A batch requires a `client`.") batch = self._client.batch() batch.create(self, document_data) kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) @@ -218,13 +259,35 @@ def create( ) -> write.WriteResult | Awaitable[write.WriteResult]: raise NotImplementedError + @overload + def _prep_set( + self: "BaseDocumentReference[AsyncClient]", + document_data: dict, + merge: bool = False, + retry: retries.AsyncRetry | None | object = None, + timeout: float | None = None, + ) -> tuple[AsyncWriteBatch, dict[str, Any]]: + pass + + @overload + def _prep_set( + self: "BaseDocumentReference[Client]", + document_data: dict, + merge: bool = False, + retry: retries.Retry | None | object = None, + timeout: float | None = None, + ) -> tuple[WriteBatch, dict[str, Any]]: + pass + def _prep_set( self, document_data: dict, merge: bool = False, retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, - ) -> Tuple[Any, dict]: + ) -> tuple[AsyncWriteBatch | WriteBatch, dict[str, Any]]: + if self._client is None: + raise ValueError("A batch requires a `client`.") batch = self._client.batch() batch.set(self, document_data, merge=merge) kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) @@ -237,16 +300,38 @@ def set( merge: bool = False, retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, - ): + ) -> write.WriteResult | Awaitable[write.WriteResult]: raise NotImplementedError + @overload + def _prep_update( + self: "BaseDocumentReference[AsyncClient]", + field_updates: dict, + option: _helpers.WriteOption | None = None, + retry: retries.AsyncRetry | None | object = None, + timeout: float | None = None, + ) -> tuple[AsyncWriteBatch, dict[str, Any]]: + pass + + @overload + def _prep_update( + self: "BaseDocumentReference[Client]", + field_updates: dict, + option: _helpers.WriteOption | None = None, + retry: retries.Retry | None | object = None, + timeout: float | None = None, + ) -> tuple[WriteBatch, dict[str, Any]]: + pass + def _prep_update( self, field_updates: dict, option: _helpers.WriteOption | None = None, retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, - ) -> Tuple[Any, dict]: + ) -> tuple[AsyncWriteBatch | WriteBatch, dict[str, Any]]: + if self._client is None: + raise ValueError("A batch requires a `client`.") batch = self._client.batch() batch.update(self, field_updates, option=option) kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) @@ -259,7 +344,7 @@ def update( option: _helpers.WriteOption | None = None, retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, - ): + ) -> write.WriteResult | Awaitable[write.WriteResult]: raise NotImplementedError def _prep_delete( @@ -267,9 +352,11 @@ def _prep_delete( option: _helpers.WriteOption | None = None, retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, - ) -> Tuple[dict, dict]: + ) -> tuple[dict[str, object], dict[str, object]]: """Shared setup for async/sync :meth:`delete`.""" write_pb = _helpers.pb_for_delete(self._document_path, option) + if self._client is None: + raise ValueError("A deletion requires a `client`.") request = { "database": self._client._database_string, "writes": [write_pb], @@ -284,7 +371,7 @@ def delete( option: _helpers.WriteOption | None = None, retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, - ): + ) -> Timestamp | Awaitable[Timestamp]: raise NotImplementedError def _prep_batch_get( @@ -294,7 +381,7 @@ def _prep_batch_get( retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, read_time: datetime.datetime | None = None, - ) -> Tuple[dict, dict]: + ) -> tuple[dict, dict[str, Any]]: """Shared setup for async/sync :meth:`get`.""" if isinstance(field_paths, str): raise ValueError("'field_paths' must be a sequence of paths, not a string.") @@ -304,6 +391,8 @@ def _prep_batch_get( else: mask = None + if self._client is None: + raise ValueError("A get requires a `client`.") request = { "database": self._client._database_string, "documents": [self._document_path], @@ -333,9 +422,9 @@ def _prep_collections( retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, read_time: datetime.datetime | None = None, - ) -> Tuple[dict, dict]: + ) -> tuple[dict, dict]: """Shared setup for async/sync :meth:`collections`.""" - request = { + request: dict[str, str | int | None | datetime.datetime] = { "parent": self._document_path, "page_size": page_size, } @@ -346,16 +435,16 @@ def _prep_collections( return request, kwargs def collections( - self, + self: Self, page_size: int | None = None, retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, *, read_time: datetime.datetime | None = None, - ): + ) -> Iterable[Self] | AsyncIterable[Self]: raise NotImplementedError - def on_snapshot(self, callback): + def on_snapshot(self, callback: Callable[[DocumentSnapshot], None]) -> Watch: raise NotImplementedError @@ -512,7 +601,7 @@ def get(self, field_path: str) -> Any: nested_data = field_path_module.get_nested_value(field_path, self._data) return copy.deepcopy(nested_data) - def to_dict(self) -> Union[Dict[str, Any], None]: + def to_dict(self) -> Dict[str, Any] | None: """Retrieve the data contained in this snapshot. A copy is returned since the data may contain mutable values, @@ -531,7 +620,7 @@ def _to_protobuf(self) -> Optional[Document]: return _helpers.document_snapshot_to_protobuf(self) -def _get_document_path(client, path: Tuple[str]) -> str: +def _get_document_path(client: C, path: tuple[str, ...]) -> str: """Convert a path tuple into a full path string. Of the form: @@ -543,7 +632,7 @@ def _get_document_path(client, path: Tuple[str]) -> str: client (:class:`~google.cloud.firestore_v1.client.Client`): The client that holds configuration details and a GAPIC client object. - path (Tuple[str, ...]): The components in a document path. + path (tuple[str, ...]): The components in a document path. Returns: str: The fully-qualified document path. diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/bulk_writer.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/bulk_writer.py index 141bc7aa6cda..c0498a24b10f 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/bulk_writer.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/bulk_writer.py @@ -566,7 +566,7 @@ def create( """ self._verify_not_closed() - if reference._document_path in self._operations_document_paths: + if reference in self._operations_document_paths: self._enqueue_current_batch() self._operations.append( @@ -576,7 +576,7 @@ def create( attempts=attempts, ), ) - self._operations_document_paths.append(reference._document_path) + self._operations_document_paths.append(reference) self._maybe_enqueue_current_batch() @@ -605,7 +605,7 @@ def delete( """ self._verify_not_closed() - if reference._document_path in self._operations_document_paths: + if reference in self._operations_document_paths: self._enqueue_current_batch() self._operations.append( @@ -615,7 +615,7 @@ def delete( attempts=attempts, ), ) - self._operations_document_paths.append(reference._document_path) + self._operations_document_paths.append(reference) self._maybe_enqueue_current_batch() @@ -648,7 +648,7 @@ def set( """ self._verify_not_closed() - if reference._document_path in self._operations_document_paths: + if reference in self._operations_document_paths: self._enqueue_current_batch() self._operations.append( @@ -659,7 +659,7 @@ def set( attempts=attempts, ) ) - self._operations_document_paths.append(reference._document_path) + self._operations_document_paths.append(reference) self._maybe_enqueue_current_batch() @@ -696,7 +696,7 @@ def update( self._verify_not_closed() - if reference._document_path in self._operations_document_paths: + if reference in self._operations_document_paths: self._enqueue_current_batch() self._operations.append( @@ -707,7 +707,7 @@ def update( attempts=attempts, ) ) - self._operations_document_paths.append(reference._document_path) + self._operations_document_paths.append(reference) self._maybe_enqueue_current_batch() diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/document.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/document.py index 11065ee8cb03..d4d831e0be4f 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/document.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/document.py @@ -18,7 +18,7 @@ import datetime import logging -from typing import Any, Callable, Generator, Iterable +from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable from google.api_core import gapic_v1 from google.api_core import retry as retries @@ -34,10 +34,15 @@ from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.watch import Watch +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.client import Client +else: + Client = None + logger = logging.getLogger(__name__) -class DocumentReference(BaseDocumentReference): +class DocumentReference(BaseDocumentReference[Client]): """A reference to a document in a Firestore database. The document may already exist or can be created by this class. @@ -353,6 +358,8 @@ def delete( """ request, kwargs = self._prep_delete(option, retry, timeout) + if self._client is None: + raise ValueError("A deletion requires a `client`.") commit_response = self._client._firestore_api.commit( request=request, metadata=self._client._rpc_metadata, @@ -410,6 +417,8 @@ def get( field_paths, transaction, retry, timeout, read_time ) + if self._client is None: + raise ValueError("A get requires a `client`.") response_iter = self._client._firestore_api.batch_get_documents( request=request, metadata=self._client._rpc_metadata, @@ -470,6 +479,8 @@ def collections( """ request, kwargs = self._prep_collections(page_size, retry, timeout, read_time) + if self._client is None: + raise ValueError("A collection reference requires a `client`.") iterator = self._client._firestore_api.list_collection_ids( request=request, metadata=self._client._rpc_metadata,