Skip to content

Commit 7e9f13a

Browse files
authored
Add return-type to public functions, mostly tests part 5 (#7691)
No change in the effective code. A batch of ~50 files. Modified files pass ruff check --select=ANN201 Notable changes: - Add no-op method `Operation.__pow__` to support static type check - Replace `list[Qid]` --> `Sequence[Qid]` which is covariant - Make `KakDecomposition` and `FSimGate` match the `SupportsUnitary` type - Loosen return type of the `_kraus_` method from `Sequence` to `Iterable` Partially implements #4393
1 parent 913e68b commit 7e9f13a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+378
-345
lines changed

cirq-core/cirq/_doc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515

1616
from __future__ import annotations
1717

18-
from typing import Any
18+
from typing import TypeVar
1919

20+
T = TypeVar('T')
2021
RECORDED_CONST_DOCS: dict[int, str] = {}
2122

2223

23-
def document(value: Any, doc_string: str = ''):
24+
def document(value: T, doc_string: str = '') -> T:
2425
"""Stores documentation details about the given value.
2526
2627
This method is used to associate a docstring with global constants. It is
@@ -64,7 +65,7 @@ def document(value: Any, doc_string: str = ''):
6465
_DOC_PRIVATE = "_tf_docs_doc_private"
6566

6667

67-
def doc_private(obj):
68+
def doc_private(obj: T) -> T:
6869
"""A decorator: Generates docs for private methods/functions.
6970
7071
For example:

cirq-core/cirq/linalg/decompositions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,9 @@ def __repr__(self) -> str:
472472
f' global_phase={self.global_phase!r})'
473473
)
474474

475+
def _has_unitary_(self) -> bool:
476+
return True
477+
475478
def _unitary_(self) -> np.ndarray:
476479
"""Returns the decomposition's two-qubit unitary matrix.
477480

cirq-core/cirq/linalg/decompositions_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def test_kak_plot_empty() -> None:
259259
)
260260
def test_kak_decomposition(target) -> None:
261261
kak = cirq.kak_decomposition(target)
262+
assert cirq.has_unitary(kak)
262263
np.testing.assert_allclose(cirq.unitary(kak), target, atol=1e-8)
263264

264265

cirq-core/cirq/ops/fourier_transform_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,7 @@ def test_qft() -> None:
107107

108108
arr = np.array([[1, 1, 1, 1], [1, -1j, -1, 1j], [1, -1, 1, -1], [1, 1j, -1, -1j]]) / 2
109109
np.testing.assert_allclose(
110-
cirq.unitary(cirq.qft(*cirq.LineQubit.range(2)) ** -1), # type: ignore[operator]
111-
arr, # type: ignore[arg-type]
112-
atol=1e-8,
110+
cirq.unitary(cirq.qft(*cirq.LineQubit.range(2)) ** -1), arr, atol=1e-8
113111
)
114112

115113
for k in range(4):
@@ -121,7 +119,7 @@ def test_qft() -> None:
121119

122120
def test_inverse() -> None:
123121
a, b, c = cirq.LineQubit.range(3)
124-
assert cirq.qft(a, b, c, inverse=True) == cirq.qft(a, b, c) ** -1 # type: ignore[operator]
122+
assert cirq.qft(a, b, c, inverse=True) == cirq.qft(a, b, c) ** -1
125123
assert cirq.qft(a, b, c, inverse=True, without_reverse=True) == cirq.inverse(
126124
cirq.qft(a, b, c, without_reverse=True)
127125
)

cirq-core/cirq/ops/raw_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,9 @@ def _num_qubits_(self) -> int:
528528
def _qid_shape_(self) -> tuple[int, ...]:
529529
return protocols.qid_shape(self.qubits)
530530

531+
def __pow__(self, exponent: Any) -> Operation:
532+
return NotImplemented # pragma: no cover
533+
531534
@abc.abstractmethod
532535
def with_qubits(self, *new_qubits: cirq.Qid) -> cirq.Operation:
533536
"""Returns the same operation, but applied to different qubits.

cirq-core/cirq/protocols/apply_unitary_protocol_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def assert_is_swap_simple(val: cirq.SupportsConsistentApplyUnitary) -> None:
240240
op_indices, tuple(qid_shape[i] for i in op_indices)
241241
)
242242
sub_result = val._apply_unitary_(sub_args)
243+
assert isinstance(sub_result, np.ndarray)
243244
result = _incorporate_result_into_target(args, sub_args, sub_result)
244245
np.testing.assert_allclose(result, expected, atol=1e-8)
245246

@@ -258,6 +259,7 @@ def assert_is_swap(val: cirq.SupportsConsistentApplyUnitary) -> None:
258259
op_indices, tuple(qid_shape[i] for i in op_indices)
259260
)
260261
sub_result = val._apply_unitary_(sub_args)
262+
assert isinstance(sub_result, np.ndarray)
261263
result = _incorporate_result_into_target(args, sub_args, sub_result)
262264
np.testing.assert_allclose(result, expected, atol=1e-8, verbose=True)
263265

cirq-core/cirq/protocols/kraus_protocol.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import warnings
2020
from types import NotImplementedType
21-
from typing import Any, Protocol, Sequence, TypeVar
21+
from typing import Any, Iterable, Protocol, TypeVar
2222

2323
import numpy as np
2424

@@ -31,7 +31,7 @@
3131

3232
# This is a special indicator value used by the channel method to determine
3333
# whether or not the caller provided a 'default' argument. It must be of type
34-
# Sequence[np.ndarray] to ensure the method has the correct type signature in
34+
# Iterable[np.ndarray] to ensure the method has the correct type signature in
3535
# that case. It is checked for using `is`, so it won't have a false positive
3636
# if the user provides a different (np.array([]),) value.
3737
RaiseTypeErrorIfNotProvided: tuple[np.ndarray] = (np.array([]),)
@@ -44,7 +44,7 @@ class SupportsKraus(Protocol):
4444
"""An object that may be describable as a quantum channel."""
4545

4646
@doc_private
47-
def _kraus_(self) -> Sequence[np.ndarray] | NotImplementedType:
47+
def _kraus_(self) -> Iterable[np.ndarray] | NotImplementedType:
4848
r"""A list of Kraus matrices describing the quantum channel.
4949
5050
These matrices are the terms in the operator sum representation of a

cirq-core/cirq/protocols/kraus_protocol_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import Iterable, Sequence
19+
from typing import Iterable
2020

2121
import numpy as np
2222
import pytest
@@ -89,7 +89,7 @@ def test_explicit_kraus() -> None:
8989
c = (a0, a1)
9090

9191
class ReturnsKraus:
92-
def _kraus_(self) -> Sequence[np.ndarray]:
92+
def _kraus_(self) -> Iterable[np.ndarray]:
9393
return c
9494

9595
assert cirq.kraus(ReturnsKraus()) is c

cirq-core/cirq/protocols/unitary_protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class SupportsUnitary(Protocol):
3838
"""An object that may be describable by a unitary matrix."""
3939

4040
@doc_private
41-
def _unitary_(self) -> np.ndarray | NotImplementedType:
41+
def _unitary_(self) -> np.ndarray | NotImplementedType | None:
4242
"""A unitary matrix describing this value, e.g. the matrix of a gate.
4343
4444
This method is used by the global `cirq.unitary` method. If this method

cirq-core/cirq/testing/consistent_resolve_parameters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import cirq
2222

2323

24-
def assert_consistent_resolve_parameters(val: Any):
24+
def assert_consistent_resolve_parameters(val: Any) -> None:
2525
names = cirq.parameter_names(val)
2626
symbols = cirq.parameter_symbols(val)
2727

0 commit comments

Comments
 (0)