From 2d66897eac939cc02dd3083c8a1f5bc71126d891 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:10:55 +0100 Subject: [PATCH 01/22] Add ANN ruff rule --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 86727712..30a86f49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,6 +134,7 @@ select = [ "I", # isort "UP", # pyupgrade "ARG", # flake8-unused-arguments + "ANN", # flake-8-annotations "B", # flake8-bugbear "C4", # flake8-comprehensions "FIX", # flake8-fixme From f27a630be972ad8e0746018d1bc4f770b49be5e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:16:22 +0100 Subject: [PATCH 02/22] Add missing -> None to __init__ --- src/torchjd/aggregation/_aggregator_bases.py | 6 +- src/torchjd/aggregation/_aligned_mtl.py | 4 +- src/torchjd/aggregation/_cagrad.py | 4 +- src/torchjd/aggregation/_config.py | 2 +- src/torchjd/aggregation/_constant.py | 4 +- src/torchjd/aggregation/_dualproj.py | 4 +- src/torchjd/aggregation/_flattening.py | 2 +- src/torchjd/aggregation/_graddrop.py | 2 +- src/torchjd/aggregation/_imtl_g.py | 2 +- src/torchjd/aggregation/_krum.py | 4 +- src/torchjd/aggregation/_mean.py | 2 +- src/torchjd/aggregation/_mgda.py | 4 +- src/torchjd/aggregation/_nash_mtl.py | 4 +- src/torchjd/aggregation/_pcgrad.py | 2 +- src/torchjd/aggregation/_random.py | 2 +- src/torchjd/aggregation/_sum.py | 2 +- src/torchjd/aggregation/_trimmed_mean.py | 2 +- src/torchjd/aggregation/_upgrad.py | 4 +- .../aggregation/_utils/non_differentiable.py | 2 +- src/torchjd/aggregation/_weighting_bases.py | 6 +- src/torchjd/autogram/_engine.py | 2 +- src/torchjd/autogram/_gramian_computer.py | 4 +- src/torchjd/autogram/_jacobian_computer.py | 2 +- src/torchjd/autogram/_module_hook_manager.py | 6 +- src/torchjd/autojac/_transform/_base.py | 4 +- .../autojac/_transform/_diagonalize.py | 2 +- .../autojac/_transform/_differentiate.py | 2 +- src/torchjd/autojac/_transform/_grad.py | 2 +- src/torchjd/autojac/_transform/_init.py | 2 +- src/torchjd/autojac/_transform/_jac.py | 2 +- .../autojac/_transform/_ordered_set.py | 2 +- src/torchjd/autojac/_transform/_select.py | 2 +- src/torchjd/autojac/_transform/_stack.py | 2 +- tests/doc/test_rst.py | 2 +- tests/plots/_utils.py | 2 +- tests/unit/aggregation/_matrix_samplers.py | 2 +- tests/unit/autojac/_transform/test_base.py | 2 +- tests/unit/autojac/_transform/test_stack.py | 2 +- tests/utils/architectures.py | 126 +++++++++--------- tests/utils/forward_backwards.py | 2 +- 40 files changed, 118 insertions(+), 118 deletions(-) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 78168eae..d4be05e9 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -13,7 +13,7 @@ class Aggregator(nn.Module, ABC): :math:`m \times n` into row vectors of dimension :math:`n`. """ - def __init__(self): + def __init__(self) -> None: super().__init__() @staticmethod @@ -48,7 +48,7 @@ class WeightedAggregator(Aggregator): :param weighting: The object responsible for extracting the vector of weights from the matrix. """ - def __init__(self, weighting: Weighting[Matrix]): + def __init__(self, weighting: Weighting[Matrix]) -> None: super().__init__() self.weighting = weighting @@ -77,6 +77,6 @@ class GramianWeightedAggregator(WeightedAggregator): gramian. """ - def __init__(self, gramian_weighting: Weighting[PSDMatrix]): + def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None: super().__init__(gramian_weighting << compute_gramian) self.gramian_weighting = gramian_weighting diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index fe807e0a..2230c27f 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -61,7 +61,7 @@ def __init__( self, pref_vector: Tensor | None = None, scale_mode: SUPPORTED_SCALE_MODE = "min", - ): + ) -> None: self._pref_vector = pref_vector self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode)) @@ -92,7 +92,7 @@ def __init__( self, pref_vector: Tensor | None = None, scale_mode: SUPPORTED_SCALE_MODE = "min", - ): + ) -> None: super().__init__() self._pref_vector = pref_vector self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index d29ca7b3..6731178b 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -34,7 +34,7 @@ class CAGrad(GramianWeightedAggregator): To install it, use ``pip install "torchjd[cagrad]"``. """ - def __init__(self, c: float, norm_eps: float = 0.0001): + def __init__(self, c: float, norm_eps: float = 0.0001) -> None: super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps)) self._c = c self._norm_eps = norm_eps @@ -67,7 +67,7 @@ class CAGradWeighting(Weighting[PSDMatrix]): function. """ - def __init__(self, c: float, norm_eps: float = 0.0001): + def __init__(self, c: float, norm_eps: float = 0.0001) -> None: super().__init__() if c < 0.0: diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py index 24866c41..447ccd3a 100644 --- a/src/torchjd/aggregation/_config.py +++ b/src/torchjd/aggregation/_config.py @@ -50,7 +50,7 @@ class ConFIG(Aggregator): `_. """ - def __init__(self, pref_vector: Tensor | None = None): + def __init__(self, pref_vector: Tensor | None = None) -> None: super().__init__() self.weighting = pref_vector_to_weighting(pref_vector, default=SumWeighting()) self._pref_vector = pref_vector diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 03b629ea..a547b813 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -15,7 +15,7 @@ class Constant(WeightedAggregator): :param weights: The weights associated to the rows of the input matrices. """ - def __init__(self, weights: Tensor): + def __init__(self, weights: Tensor) -> None: super().__init__(weighting=ConstantWeighting(weights=weights)) self._weights = weights @@ -35,7 +35,7 @@ class ConstantWeighting(Weighting[Matrix]): :param weights: The weights to return at each call. """ - def __init__(self, weights: Tensor): + def __init__(self, weights: Tensor) -> None: if weights.dim() != 1: raise ValueError( "Parameter `weights` should be a 1-dimensional tensor. Found `weights.shape = " diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index d7d88648..7e868f62 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -33,7 +33,7 @@ def __init__( norm_eps: float = 0.0001, reg_eps: float = 0.0001, solver: SUPPORTED_SOLVER = "quadprog", - ): + ) -> None: self._pref_vector = pref_vector self._norm_eps = norm_eps self._reg_eps = reg_eps @@ -77,7 +77,7 @@ def __init__( norm_eps: float = 0.0001, reg_eps: float = 0.0001, solver: SUPPORTED_SOLVER = "quadprog", - ): + ) -> None: super().__init__() self._pref_vector = pref_vector self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) diff --git a/src/torchjd/aggregation/_flattening.py b/src/torchjd/aggregation/_flattening.py index 208db3ec..15736b52 100644 --- a/src/torchjd/aggregation/_flattening.py +++ b/src/torchjd/aggregation/_flattening.py @@ -20,7 +20,7 @@ class Flattening(GeneralizedWeighting): :param weighting: The weighting to apply to the Gramian matrix. """ - def __init__(self, weighting: Weighting): + def __init__(self, weighting: Weighting) -> None: super().__init__() self.weighting = weighting diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index afa16451..61c9354e 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -26,7 +26,7 @@ class GradDrop(Aggregator): through. Defaults to None, which means no leak. """ - def __init__(self, f: Callable = _identity, leak: Tensor | None = None): + def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None: if leak is not None and leak.dim() != 1: raise ValueError( "Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = " diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index f45e8c2e..75d00b76 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -16,7 +16,7 @@ class IMTLG(GramianWeightedAggregator): `_, supports matrices with some linearly dependant rows. """ - def __init__(self): + def __init__(self) -> None: super().__init__(IMTLGWeighting()) # This prevents computing gradients that can be very wrong. diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index 93565174..40285d89 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -19,7 +19,7 @@ class Krum(GramianWeightedAggregator): :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1. """ - def __init__(self, n_byzantine: int, n_selected: int = 1): + def __init__(self, n_byzantine: int, n_selected: int = 1) -> None: self._n_byzantine = n_byzantine self._n_selected = n_selected super().__init__(KrumWeighting(n_byzantine=n_byzantine, n_selected=n_selected)) @@ -44,7 +44,7 @@ class KrumWeighting(Weighting[PSDMatrix]): :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1. """ - def __init__(self, n_byzantine: int, n_selected: int = 1): + def __init__(self, n_byzantine: int, n_selected: int = 1) -> None: super().__init__() if n_byzantine < 0: raise ValueError( diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index d7085e10..8fc5b057 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -13,7 +13,7 @@ class Mean(WeightedAggregator): matrices. """ - def __init__(self): + def __init__(self) -> None: super().__init__(weighting=MeanWeighting()) diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index 8f753c2a..aec32947 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -20,7 +20,7 @@ class MGDA(GramianWeightedAggregator): :param max_iters: The maximum number of iterations of the optimization loop. """ - def __init__(self, epsilon: float = 0.001, max_iters: int = 100): + def __init__(self, epsilon: float = 0.001, max_iters: int = 100) -> None: super().__init__(MGDAWeighting(epsilon=epsilon, max_iters=max_iters)) self._epsilon = epsilon self._max_iters = max_iters @@ -38,7 +38,7 @@ class MGDAWeighting(Weighting[PSDMatrix]): :param max_iters: The maximum number of iterations of the optimization loop. """ - def __init__(self, epsilon: float = 0.001, max_iters: int = 100): + def __init__(self, epsilon: float = 0.001, max_iters: int = 100) -> None: super().__init__() self.epsilon = epsilon self.max_iters = max_iters diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 83455245..06e1293d 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -77,7 +77,7 @@ def __init__( max_norm: float = 1.0, update_weights_every: int = 1, optim_niter: int = 20, - ): + ) -> None: super().__init__( weighting=_NashMTLWeighting( n_tasks=n_tasks, @@ -126,7 +126,7 @@ def __init__( max_norm: float, update_weights_every: int, optim_niter: int, - ): + ) -> None: super().__init__() self.n_tasks = n_tasks diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index d6cc3f10..0f1241df 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -16,7 +16,7 @@ class PCGrad(GramianWeightedAggregator): `Gradient Surgery for Multi-Task Learning `_. """ - def __init__(self): + def __init__(self) -> None: super().__init__(PCGradWeighting()) # This prevents running into a RuntimeError due to modifying stored tensors in place. diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 53ef188c..734dfc17 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -16,7 +16,7 @@ class Random(WeightedAggregator): `_. """ - def __init__(self): + def __init__(self) -> None: super().__init__(RandomWeighting()) diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index 0d8bd5d6..aaf73f02 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -13,7 +13,7 @@ class Sum(WeightedAggregator): matrices. """ - def __init__(self): + def __init__(self) -> None: super().__init__(weighting=SumWeighting()) diff --git a/src/torchjd/aggregation/_trimmed_mean.py b/src/torchjd/aggregation/_trimmed_mean.py index 77d33c41..8dffe990 100644 --- a/src/torchjd/aggregation/_trimmed_mean.py +++ b/src/torchjd/aggregation/_trimmed_mean.py @@ -15,7 +15,7 @@ class TrimmedMean(Aggregator): input matrix (note that ``2 * trim_number`` values are removed from each column). """ - def __init__(self, trim_number: int): + def __init__(self, trim_number: int) -> None: super().__init__() if trim_number < 0: raise ValueError( diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 8234b3a8..45f760be 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -34,7 +34,7 @@ def __init__( norm_eps: float = 0.0001, reg_eps: float = 0.0001, solver: SUPPORTED_SOLVER = "quadprog", - ): + ) -> None: self._pref_vector = pref_vector self._norm_eps = norm_eps self._reg_eps = reg_eps @@ -78,7 +78,7 @@ def __init__( norm_eps: float = 0.0001, reg_eps: float = 0.0001, solver: SUPPORTED_SOLVER = "quadprog", - ): + ) -> None: super().__init__() self._pref_vector = pref_vector self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) diff --git a/src/torchjd/aggregation/_utils/non_differentiable.py b/src/torchjd/aggregation/_utils/non_differentiable.py index e3f85420..c5fb1ffc 100644 --- a/src/torchjd/aggregation/_utils/non_differentiable.py +++ b/src/torchjd/aggregation/_utils/non_differentiable.py @@ -2,7 +2,7 @@ class NonDifferentiableError(RuntimeError): - def __init__(self, module: nn.Module): + def __init__(self, module: nn.Module) -> None: super().__init__(f"Trying to differentiate through {module}, which is not differentiable.") diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index dd7c53ee..e321169c 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -20,7 +20,7 @@ class Weighting(nn.Module, ABC, Generic[_T]): generally its Gramian, of dimension :math:`m \times m`. """ - def __init__(self): + def __init__(self) -> None: super().__init__() @abstractmethod @@ -46,7 +46,7 @@ class _Composition(Weighting[_T]): output of the function. """ - def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutputT]): + def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutputT]) -> None: super().__init__() self.fn = fn self.weighting = weighting @@ -63,7 +63,7 @@ class GeneralizedWeighting(nn.Module, ABC): :math:`m_1 \times \dots \times m_k \times m_k \times \dots \times m_1`. """ - def __init__(self): + def __init__(self) -> None: super().__init__() @abstractmethod diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 610b3753..7b7eae96 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -183,7 +183,7 @@ def __init__( self, *modules: nn.Module, batch_dim: int | None, - ): + ) -> None: self._gramian_accumulator = GramianAccumulator() self._target_edges = EdgeRegistry() self._batch_dim = batch_dim diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index 829e5da3..cdc7ce93 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -29,7 +29,7 @@ def reset(self) -> None: class JacobianBasedGramianComputer(GramianComputer, ABC): - def __init__(self, jacobian_computer: JacobianComputer): + def __init__(self, jacobian_computer: JacobianComputer) -> None: self.jacobian_computer = jacobian_computer @@ -39,7 +39,7 @@ class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): the gramian. """ - def __init__(self, jacobian_computer: JacobianComputer): + def __init__(self, jacobian_computer: JacobianComputer) -> None: super().__init__(jacobian_computer) self.remaining_counter = 0 self.summed_jacobian: Matrix | None = None diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 45cc71ba..347dcc1a 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -25,7 +25,7 @@ class JacobianComputer(ABC): :params module: The module to differentiate. """ - def __init__(self, module: nn.Module): + def __init__(self, module: nn.Module) -> None: self.module = module self.rg_params = dict[str, Parameter]() diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index e2958d93..20420ac6 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -33,7 +33,7 @@ def __init__( self, target_edges: EdgeRegistry, gramian_accumulator: GramianAccumulator, - ): + ) -> None: self._target_edges = target_edges self._gramian_accumulator = gramian_accumulator self.gramian_accumulation_phase = BoolRef(False) @@ -79,7 +79,7 @@ def remove_hooks(handles: list[TorchRemovableHandle]) -> None: class BoolRef: """Class wrapping a boolean value, acting as a reference to this boolean value.""" - def __init__(self, value: bool): + def __init__(self, value: bool) -> None: self.value = value def __bool__(self) -> bool: @@ -93,7 +93,7 @@ def __init__( target_edges: EdgeRegistry, gramian_accumulator: GramianAccumulator, gramian_computer: GramianComputer, - ): + ) -> None: self.gramian_accumulation_phase = gramian_accumulation_phase self.target_edges = target_edges self.gramian_accumulator = gramian_accumulator diff --git a/src/torchjd/autojac/_transform/_base.py b/src/torchjd/autojac/_transform/_base.py index 579b845c..4cba5120 100644 --- a/src/torchjd/autojac/_transform/_base.py +++ b/src/torchjd/autojac/_transform/_base.py @@ -67,7 +67,7 @@ class Composition(Transform): :param outer: The transform to apply second, to the result of ``inner``. """ - def __init__(self, outer: Transform, inner: Transform): + def __init__(self, outer: Transform, inner: Transform) -> None: self.outer = outer self.inner = inner @@ -92,7 +92,7 @@ class Conjunction(Transform): :param transforms: The transforms to apply. Their outputs should have disjoint sets of keys. """ - def __init__(self, transforms: Sequence[Transform]): + def __init__(self, transforms: Sequence[Transform]) -> None: self.transforms = transforms def __str__(self) -> str: diff --git a/src/torchjd/autojac/_transform/_diagonalize.py b/src/torchjd/autojac/_transform/_diagonalize.py index 7954d7ce..11c951de 100644 --- a/src/torchjd/autojac/_transform/_diagonalize.py +++ b/src/torchjd/autojac/_transform/_diagonalize.py @@ -51,7 +51,7 @@ class Diagonalize(Transform): Jacobians. """ - def __init__(self, key_order: OrderedSet[Tensor]): + def __init__(self, key_order: OrderedSet[Tensor]) -> None: self.key_order = key_order self.indices: list[tuple[int, int]] = [] begin = 0 diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/autojac/_transform/_differentiate.py index 1ce26438..458bd8d0 100644 --- a/src/torchjd/autojac/_transform/_differentiate.py +++ b/src/torchjd/autojac/_transform/_differentiate.py @@ -31,7 +31,7 @@ def __init__( inputs: OrderedSet[Tensor], retain_graph: bool, create_graph: bool, - ): + ) -> None: self.outputs = list(outputs) self.inputs = list(inputs) self.retain_graph = retain_graph diff --git a/src/torchjd/autojac/_transform/_grad.py b/src/torchjd/autojac/_transform/_grad.py index e61d4748..a4bd4ff3 100644 --- a/src/torchjd/autojac/_transform/_grad.py +++ b/src/torchjd/autojac/_transform/_grad.py @@ -31,7 +31,7 @@ def __init__( inputs: OrderedSet[Tensor], retain_graph: bool = False, create_graph: bool = False, - ): + ) -> None: super().__init__(outputs, inputs, retain_graph, create_graph) def _differentiate(self, grad_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]: diff --git a/src/torchjd/autojac/_transform/_init.py b/src/torchjd/autojac/_transform/_init.py index 50833032..9da503ed 100644 --- a/src/torchjd/autojac/_transform/_init.py +++ b/src/torchjd/autojac/_transform/_init.py @@ -13,7 +13,7 @@ class Init(Transform): :param values: Tensors for which Gradients must be returned. """ - def __init__(self, values: AbstractSet[Tensor]): + def __init__(self, values: AbstractSet[Tensor]) -> None: self.values = values def __call__(self, _input: TensorDict, /) -> TensorDict: diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/autojac/_transform/_jac.py index 1f33d1d9..e245fae0 100644 --- a/src/torchjd/autojac/_transform/_jac.py +++ b/src/torchjd/autojac/_transform/_jac.py @@ -38,7 +38,7 @@ def __init__( chunk_size: int | None, retain_graph: bool = False, create_graph: bool = False, - ): + ) -> None: super().__init__(outputs, inputs, retain_graph, create_graph) self.chunk_size = chunk_size diff --git a/src/torchjd/autojac/_transform/_ordered_set.py b/src/torchjd/autojac/_transform/_ordered_set.py index c929cb45..e182df89 100644 --- a/src/torchjd/autojac/_transform/_ordered_set.py +++ b/src/torchjd/autojac/_transform/_ordered_set.py @@ -10,7 +10,7 @@ class OrderedSet(MutableSet[_T]): """Ordered collection of distinct elements.""" - def __init__(self, elements: Iterable[_T]): + def __init__(self, elements: Iterable[_T]) -> None: super().__init__() self.ordered_dict = OrderedDict[_T, None]([(element, None) for element in elements]) diff --git a/src/torchjd/autojac/_transform/_select.py b/src/torchjd/autojac/_transform/_select.py index b2e45caa..9df527ff 100644 --- a/src/torchjd/autojac/_transform/_select.py +++ b/src/torchjd/autojac/_transform/_select.py @@ -12,7 +12,7 @@ class Select(Transform): :param keys: The keys that should be included in the returned subset. """ - def __init__(self, keys: AbstractSet[Tensor]): + def __init__(self, keys: AbstractSet[Tensor]) -> None: self.keys = keys def __call__(self, tensor_dict: TensorDict, /) -> TensorDict: diff --git a/src/torchjd/autojac/_transform/_stack.py b/src/torchjd/autojac/_transform/_stack.py index a4152afc..ad628e5d 100644 --- a/src/torchjd/autojac/_transform/_stack.py +++ b/src/torchjd/autojac/_transform/_stack.py @@ -20,7 +20,7 @@ class Stack(Transform): to those dicts. """ - def __init__(self, transforms: Sequence[Transform]): + def __init__(self, transforms: Sequence[Transform]) -> None: self.transforms = transforms def __call__(self, input: TensorDict, /) -> TensorDict: diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index ac4ac060..f1df52d9 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -231,7 +231,7 @@ def test_lightning_integration(): from torchjd.autojac import jac_to_grad, mtl_backward class Model(LightningModule): - def __init__(self): + def __init__(self) -> None: super().__init__() self.feature_extractor = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) self.task1_head = Linear(3, 1) diff --git a/tests/plots/_utils.py b/tests/plots/_utils.py index e184d6d9..dc69bfda 100644 --- a/tests/plots/_utils.py +++ b/tests/plots/_utils.py @@ -7,7 +7,7 @@ class Plotter: - def __init__(self, aggregators: list[Aggregator], matrix: torch.Tensor, seed: int = 0): + def __init__(self, aggregators: list[Aggregator], matrix: torch.Tensor, seed: int = 0) -> None: self.aggregators = aggregators self.matrix = matrix self.seed = seed diff --git a/tests/unit/aggregation/_matrix_samplers.py b/tests/unit/aggregation/_matrix_samplers.py index 1b5cc8ab..68699cce 100644 --- a/tests/unit/aggregation/_matrix_samplers.py +++ b/tests/unit/aggregation/_matrix_samplers.py @@ -9,7 +9,7 @@ class MatrixSampler(ABC): """Abstract base class for sampling matrices of a given shape, rank.""" - def __init__(self, m: int, n: int, rank: int): + def __init__(self, m: int, n: int, rank: int) -> None: self._check_params(m, n, rank) self.m = m self.n = n diff --git a/tests/unit/autojac/_transform/test_base.py b/tests/unit/autojac/_transform/test_base.py index 254147bd..e1e47ad2 100644 --- a/tests/unit/autojac/_transform/test_base.py +++ b/tests/unit/autojac/_transform/test_base.py @@ -10,7 +10,7 @@ class FakeTransform(Transform): Fake ``Transform`` to test `check_keys` when composing and conjuncting. """ - def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]): + def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]) -> None: self._required_keys = required_keys self._output_keys = output_keys diff --git a/tests/unit/autojac/_transform/test_stack.py b/tests/unit/autojac/_transform/test_stack.py index 35e617d1..ef644fb5 100644 --- a/tests/unit/autojac/_transform/test_stack.py +++ b/tests/unit/autojac/_transform/test_stack.py @@ -12,7 +12,7 @@ class FakeGradientsTransform(Transform): """Transform that produces gradients filled with ones, for testing purposes.""" - def __init__(self, keys: Iterable[Tensor]): + def __init__(self, keys: Iterable[Tensor]) -> None: self.keys = set(keys) def __call__(self, _input: TensorDict, /) -> TensorDict: diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index cf3261f8..d83917cd 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -14,7 +14,7 @@ class ModuleFactory(Generic[_T]): - def __init__(self, architecture: type[_T], *args, **kwargs): + def __init__(self, architecture: type[_T], *args, **kwargs) -> None: self.architecture: type[_T] = architecture self.args = args self.kwargs = kwargs @@ -63,7 +63,7 @@ class OverlyNested(ShapedModule): INPUT_SHAPES = (9,) OUTPUT_SHAPES = (14,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.seq = nn.Sequential( nn.Sequential( @@ -95,7 +95,7 @@ class MultiInputSingleOutput(ShapedModule): INPUT_SHAPES = ((50,), (50,)) OUTPUT_SHAPES = (60,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix1 = nn.Parameter(torch.randn(50, 60)) self.matrix2 = nn.Parameter(torch.randn(50, 60)) @@ -112,7 +112,7 @@ class MultiInputMultiOutput(ShapedModule): INPUT_SHAPES = ((50,), (50,)) OUTPUT_SHAPES = ((60,), (70,)) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix1_1 = nn.Parameter(torch.randn(50, 60)) self.matrix2_1 = nn.Parameter(torch.randn(50, 60)) @@ -136,7 +136,7 @@ class SingleInputPyTreeOutput(ShapedModule): "third": ([((90,),)],), } - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix1 = nn.Parameter(torch.randn(50, 50)) self.matrix2 = nn.Parameter(torch.randn(50, 60)) @@ -161,7 +161,7 @@ class PyTreeInputSingleOutput(ShapedModule): } OUTPUT_SHAPES = (350,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix1 = nn.Parameter(torch.randn(10, 50)) self.matrix2 = nn.Parameter(torch.randn(20, 60)) @@ -203,7 +203,7 @@ class PyTreeInputPyTreeOutput(ShapedModule): "third": ([((90,),)],), } - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix1 = nn.Parameter(torch.randn(10, 50)) self.matrix2 = nn.Parameter(torch.randn(20, 60)) @@ -231,7 +231,7 @@ class SimpleBranched(ShapedModule): INPUT_SHAPES = (9,) OUTPUT_SHAPES = (16,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.relu = nn.ReLU() self.fc0 = nn.Linear(9, 13) @@ -257,7 +257,7 @@ class MISOBranched(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = MultiInputSingleOutput.OUTPUT_SHAPES - def __init__(self): + def __init__(self) -> None: super().__init__() self.miso = MultiInputSingleOutput() @@ -274,7 +274,7 @@ class MIMOBranched(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (130,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.mimo = MultiInputMultiOutput() @@ -291,7 +291,7 @@ class SIPOBranched(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (350,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.sipo = SingleInputPyTreeOutput() @@ -314,7 +314,7 @@ class PISOBranched(ShapedModule): INPUT_SHAPES = (86,) OUTPUT_SHAPES = (350,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.piso = PyTreeInputSingleOutput() @@ -342,7 +342,7 @@ class PIPOBranched(ShapedModule): INPUT_SHAPES = (86,) OUTPUT_SHAPES = (350,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.pipo = PyTreeInputPyTreeOutput() @@ -379,7 +379,7 @@ class WithNoTensorOutput(ShapedModule): OUTPUT_SHAPES = (10,) class _NoneOutput(nn.Module): - def __init__(self, shape: tuple[int, ...]): + def __init__(self, shape: tuple[int, ...]) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(shape)) @@ -387,7 +387,7 @@ def forward(self, _: PyTree) -> None: pass class _NonePyTreeOutput(nn.Module): - def __init__(self, shape: tuple[int, ...]): + def __init__(self, shape: tuple[int, ...]) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(shape)) @@ -395,7 +395,7 @@ def forward(self, _: PyTree) -> PyTree: return {"one": [None, ()], "two": None} class _EmptyTupleOutput(nn.Module): - def __init__(self, shape: tuple[int, ...]): + def __init__(self, shape: tuple[int, ...]) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(shape)) @@ -403,14 +403,14 @@ def forward(self, _: PyTree) -> tuple: return () class _EmptyPytreeOutput(nn.Module): - def __init__(self, shape: tuple[int, ...]): + def __init__(self, shape: tuple[int, ...]) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(shape)) def forward(self, _: PyTree) -> PyTree: return {"one": [(), ()], "two": [[], []]} - def __init__(self): + def __init__(self) -> None: super().__init__() self.none_output = self._NoneOutput((27, 10)) self.none_pytree_output = self._NonePyTreeOutput((27, 10)) @@ -432,7 +432,7 @@ class IntraModuleParamReuse(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(50, 10)) @@ -452,14 +452,14 @@ class _MatMulModule(nn.Module): that this parameter can be used in other modules too. """ - def __init__(self, matrix: nn.Parameter): + def __init__(self, matrix: nn.Parameter) -> None: super().__init__() self.matrix = matrix def forward(self, input: Tensor): return input @ self.matrix - def __init__(self): + def __init__(self) -> None: super().__init__() matrix = nn.Parameter(torch.randn(50, 10)) self.module1 = self._MatMulModule(matrix) @@ -475,7 +475,7 @@ class ModuleReuse(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.module = nn.Linear(50, 10) @@ -489,7 +489,7 @@ class SomeUnusedParam(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.unused_param = nn.Parameter(torch.randn(50, 10)) self.matrix = nn.Parameter(torch.randn(50, 10)) @@ -507,7 +507,7 @@ class SomeFrozenParam(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.frozen_param = nn.Parameter(torch.randn(50, 10), requires_grad=False) self.matrix = nn.Parameter(torch.randn(50, 10)) @@ -524,7 +524,7 @@ class WithSomeFrozenModule(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.non_frozen = nn.Linear(50, 10) self.all_frozen = nn.Linear(50, 10) @@ -553,7 +553,7 @@ class SomeFrozenParamAndUnusedTrainableParam(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.frozen_param = nn.Parameter(torch.randn(50, 10), requires_grad=False) self.non_frozen_param = nn.Parameter(torch.randn(50, 10)) @@ -561,7 +561,7 @@ def __init__(self): def forward(self, input: Tensor) -> Tensor: return input @ self.frozen_param - def __init__(self): + def __init__(self) -> None: super().__init__() self.weird_module = self.SomeFrozenParamAndUnusedTrainableParam() self.normal_module = nn.Linear(10, 3) @@ -579,7 +579,7 @@ class MultiOutputWithFrozenBranch(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = ((10,), (10,)) - def __init__(self): + def __init__(self) -> None: super().__init__() self.frozen_param = nn.Parameter(torch.randn(50, 10), requires_grad=False) self.matrix = nn.Parameter(torch.randn(50, 10)) @@ -597,14 +597,14 @@ class WithBuffered(ShapedModule): class _Buffered(nn.Module): buffer: Tensor - def __init__(self): + def __init__(self) -> None: super().__init__() self.register_buffer("buffer", torch.tensor(1.5)) def forward(self, input: Tensor) -> Tensor: return input * self.buffer - def __init__(self): + def __init__(self) -> None: super().__init__() self.module_with_buffer = self._Buffered() self.linear = nn.Linear(27, 10) @@ -619,7 +619,7 @@ class Randomness(ShapedModule): INPUT_SHAPES = (9,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(9, 10)) @@ -635,7 +635,7 @@ class WithSideEffect(ShapedModule): INPUT_SHAPES = (9,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(9, 10)) self.register_buffer("buffer", torch.zeros((9,))) @@ -654,7 +654,7 @@ class SomeUnusedOutput(ShapedModule): INPUT_SHAPES = (9,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear1 = nn.Linear(9, 12) self.linear2 = nn.Linear(9, 10) @@ -671,7 +671,7 @@ class Ndim0Output(ShapedModule): INPUT_SHAPES = (5,) OUTPUT_SHAPES = () - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear = nn.Linear(5, 1) @@ -685,7 +685,7 @@ class Ndim1Output(ShapedModule): INPUT_SHAPES = (5,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear = nn.Linear(5, 3) @@ -699,7 +699,7 @@ class Ndim2Output(ShapedModule): INPUT_SHAPES = (5,) OUTPUT_SHAPES = (2, 3) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear1 = nn.Linear(5, 3) self.linear2 = nn.Linear(5, 3) @@ -714,7 +714,7 @@ class Ndim3Output(ShapedModule): INPUT_SHAPES = (6,) OUTPUT_SHAPES = (2, 3, 4) - def __init__(self): + def __init__(self) -> None: super().__init__() self.tensor = nn.Parameter(torch.randn(6, 2, 3, 4)) @@ -728,7 +728,7 @@ class Ndim4Output(ShapedModule): INPUT_SHAPES = (6,) OUTPUT_SHAPES = (2, 3, 4, 5) - def __init__(self): + def __init__(self) -> None: super().__init__() self.tensor = nn.Parameter(torch.randn(6, 2, 3, 4, 5)) @@ -742,7 +742,7 @@ class WithRNN(ShapedModule): INPUT_SHAPES = (20, 8) # Size 20, dim input_size (8) OUTPUT_SHAPES = (20, 5) # Size 20, dim hidden_size (5) - def __init__(self): + def __init__(self) -> None: super().__init__() self.rnn = nn.RNN(input_size=8, hidden_size=5, batch_first=True) @@ -757,7 +757,7 @@ class WithDropout(ShapedModule): INPUT_SHAPES = (3, 6, 6) OUTPUT_SHAPES = (3, 4, 4) - def __init__(self): + def __init__(self) -> None: super().__init__() self.conv2d = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) self.dropout = nn.Dropout2d(p=0.5) @@ -775,7 +775,7 @@ class ModelUsingSubmoduleParamsDirectly(ShapedModule): INPUT_SHAPES = (2,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear = nn.Linear(2, 3) @@ -791,7 +791,7 @@ class ModelAlsoUsingSubmoduleParamsDirectly(ShapedModule): INPUT_SHAPES = (2,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear = nn.Linear(2, 3) @@ -800,7 +800,7 @@ def forward(self, input: Tensor) -> Tensor: class _WithStringArg(nn.Module): - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(2, 3)) @@ -816,7 +816,7 @@ class WithModuleWithStringArg(ShapedModule): INPUT_SHAPES = (2,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.with_string_arg = _WithStringArg() @@ -830,7 +830,7 @@ class WithModuleWithStringKwarg(ShapedModule): INPUT_SHAPES = (2,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.with_string_arg = _WithStringArg() @@ -839,7 +839,7 @@ def forward(self, input: Tensor) -> Tensor: class _WithHybridPyTreeArg(nn.Module): - def __init__(self): + def __init__(self) -> None: super().__init__() self.m0 = nn.Parameter(torch.randn(3, 3)) self.m1 = nn.Parameter(torch.randn(4, 3)) @@ -869,7 +869,7 @@ class WithModuleWithHybridPyTreeArg(ShapedModule): INPUT_SHAPES = (10,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear = nn.Linear(10, 18) self.with_string_arg = _WithHybridPyTreeArg() @@ -898,7 +898,7 @@ class WithModuleWithHybridPyTreeKwarg(ShapedModule): INPUT_SHAPES = (10,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear = nn.Linear(10, 18) self.with_string_arg = _WithHybridPyTreeArg() @@ -925,14 +925,14 @@ class WithModuleWithStringOutput(ShapedModule): OUTPUT_SHAPES = (3,) class WithStringOutput(nn.Module): - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(2, 3)) def forward(self, input: Tensor) -> tuple[str, Tensor]: return "test", input @ self.matrix - def __init__(self): + def __init__(self) -> None: super().__init__() self.with_string_output = self.WithStringOutput() @@ -947,7 +947,7 @@ class WithMultiHeadAttention(ShapedModule): INPUT_SHAPES = ((20, 8), (10, 9), (10, 11)) OUTPUT_SHAPES = (20, 8) - def __init__(self): + def __init__(self) -> None: super().__init__() self.mha = nn.MultiheadAttention( embed_dim=8, @@ -970,7 +970,7 @@ class WithTransformer(ShapedModule): INPUT_SHAPES = ((10, 8), (20, 8)) OUTPUT_SHAPES = (20, 8) - def __init__(self): + def __init__(self) -> None: super().__init__() self.transformer = nn.Transformer( d_model=8, @@ -993,7 +993,7 @@ class WithTransformerLarge(ShapedModule): INPUT_SHAPES = ((10, 512), (20, 512)) OUTPUT_SHAPES = (20, 512) - def __init__(self): + def __init__(self) -> None: super().__init__() self.transformer = nn.Transformer( batch_first=True, @@ -1014,7 +1014,7 @@ class FreeParam(ShapedModule): INPUT_SHAPES = (15,) OUTPUT_SHAPES = (80,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(15, 16)) # Free parameter self.relu = nn.ReLU() @@ -1041,7 +1041,7 @@ class NoFreeParam(ShapedModule): INPUT_SHAPES = (15,) OUTPUT_SHAPES = (80,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear0 = nn.Linear(15, 16, bias=False) self.relu = nn.ReLU() @@ -1071,7 +1071,7 @@ class Body(ShapedModule): INPUT_SHAPES = (3, 32, 32) OUTPUT_SHAPES = (1024,) - def __init__(self): + def __init__(self) -> None: super().__init__() layers = [ nn.Conv2d(3, 32, 3), @@ -1092,7 +1092,7 @@ class Head(ShapedModule): INPUT_SHAPES = (1024,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() layers = [ nn.Linear(1024, 128), @@ -1107,7 +1107,7 @@ def forward(self, input: Tensor) -> Tensor: INPUT_SHAPES = Body.INPUT_SHAPES OUTPUT_SHAPES = Head.OUTPUT_SHAPES - def __init__(self): + def __init__(self) -> None: super().__init__() self.body = self.Body() self.head = self.Head() @@ -1128,7 +1128,7 @@ class AlexNet(ShapedModule): INPUT_SHAPES = (3, 224, 224) OUTPUT_SHAPES = (1000,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.alexnet = torchvision.models.alexnet() @@ -1145,7 +1145,7 @@ class InstanceNormResNet18(ShapedModule): INPUT_SHAPES = (3, 224, 224) OUTPUT_SHAPES = (1000,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.resnet18 = torchvision.models.resnet18( norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True), @@ -1161,7 +1161,7 @@ class GroupNormMobileNetV3Small(ShapedModule): INPUT_SHAPES = (3, 224, 224) OUTPUT_SHAPES = (1000,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.mobile_net = torchvision.models.mobilenet_v3_small( norm_layer=partial(nn.GroupNorm, 2, affine=True), @@ -1177,7 +1177,7 @@ class SqueezeNet(ShapedModule): INPUT_SHAPES = (3, 224, 224) OUTPUT_SHAPES = (1000,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.squeezenet = torchvision.models.squeezenet1_0() @@ -1191,7 +1191,7 @@ class InstanceNormMobileNetV2(ShapedModule): INPUT_SHAPES = (3, 224, 224) OUTPUT_SHAPES = (1000,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.mobilenet = torchvision.models.mobilenet_v2( norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True), diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 57f9b90a..45f1417d 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -162,7 +162,7 @@ class CloneParams: algorithm rather than a module-based algorithm. """ - def __init__(self, model: nn.Module): + def __init__(self, model: nn.Module) -> None: self.model = model self.clones = list[nn.Parameter]() self._module_to_original_params = dict[nn.Module, dict[str, nn.Parameter]]() From 0cb67ca9e8acbff684d51e736baee34e06300adc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:25:34 +0100 Subject: [PATCH 03/22] Improve ANN rule configuration --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 30a86f49..e20ac573 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,12 +157,16 @@ ignore = [ "RUF012", # Mutable default value for class attribute (a bit tedious to fix) "RET504", # Unnecessary assignment return statement "COM812", # Trailing comma missing (conflicts with formatter, see https://github.com/astral-sh/ruff/issues/9216) + "ANN401", # Prevent annotating as Any (we rarely do that, and when we do it's hard to find an alternative) ] [tool.ruff.lint.per-file-ignores] "**/conftest.py" = ["ARG"] # Can't change argument names in the functions pytest expects "tests/doc/test_rst.py" = ["ARG"] # For the lightning example +[tool.ruff.lint.flake8-annotations] +suppress-dummy-args = true + [tool.ruff.lint.isort] combine-as-imports = true From f67f8d27be3e64091e770b4fc28ae1457b622aa3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:25:50 +0100 Subject: [PATCH 04/22] Fix untyped ctx --- src/torchjd/autogram/_module_hook_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 20420ac6..1afee670 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -1,5 +1,5 @@ import weakref -from typing import cast +from typing import Any, cast import torch from torch import Tensor, nn @@ -170,7 +170,7 @@ def forward( # tuple[BoolRef, GramianComputer, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, *tuple[Tensor, ...]] @staticmethod def setup_context( - ctx, + ctx: Any, inputs: tuple, _, ) -> None: # type: ignore[reportIncompatibleMethodOverride] @@ -182,7 +182,7 @@ def setup_context( ctx.rg_outputs = inputs[5:] @staticmethod - def backward(ctx, *grad_outputs: Tensor) -> tuple: + def backward(ctx: Any, *grad_outputs: Tensor) -> tuple: # For python > 3.10: -> tuple[None, None, None, None, None, *tuple[Tensor, ...]] if ctx.gramian_accumulation_phase: From 905c658a863ef51f15a7dc66415b30807907ee35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:38:46 +0100 Subject: [PATCH 05/22] Add missing -> None in tests --- tests/conftest.py | 6 +- tests/doc/test_aggregation.py | 4 +- tests/doc/test_autogram.py | 2 +- tests/doc/test_backward.py | 4 +- tests/doc/test_jac.py | 6 +- tests/doc/test_jac_to_grad.py | 2 +- tests/doc/test_rst.py | 24 ++++---- tests/plots/static_plotter.py | 2 +- tests/profiling/run_profiler.py | 6 +- tests/profiling/speed_grad_vs_jac_vs_gram.py | 28 ++++----- tests/unit/aggregation/_asserts.py | 2 +- .../unit/aggregation/_utils/test_dual_cone.py | 8 +-- .../aggregation/_utils/test_pref_vector.py | 4 +- .../unit/aggregation/test_aggregator_bases.py | 2 +- tests/unit/aggregation/test_aligned_mtl.py | 8 +-- tests/unit/aggregation/test_cagrad.py | 10 ++-- tests/unit/aggregation/test_config.py | 10 ++-- tests/unit/aggregation/test_constant.py | 14 +++-- tests/unit/aggregation/test_dualproj.py | 12 ++-- tests/unit/aggregation/test_graddrop.py | 12 ++-- tests/unit/aggregation/test_imtl_g.py | 10 ++-- tests/unit/aggregation/test_krum.py | 10 ++-- tests/unit/aggregation/test_mean.py | 10 ++-- tests/unit/aggregation/test_mgda.py | 10 ++-- tests/unit/aggregation/test_nash_mtl.py | 8 +-- tests/unit/aggregation/test_pcgrad.py | 8 +-- tests/unit/aggregation/test_random.py | 6 +- tests/unit/aggregation/test_sum.py | 10 ++-- tests/unit/aggregation/test_trimmed_mean.py | 10 ++-- tests/unit/aggregation/test_upgrad.py | 14 ++--- tests/unit/aggregation/test_values.py | 4 +- tests/unit/autogram/test_edge_registry.py | 8 +-- tests/unit/autogram/test_engine.py | 30 +++++----- tests/unit/autogram/test_gramian_utils.py | 12 ++-- .../autojac/_transform/test_accumulate.py | 22 +++---- tests/unit/autojac/_transform/test_base.py | 10 ++-- .../autojac/_transform/test_diagonalize.py | 8 +-- tests/unit/autojac/_transform/test_grad.py | 28 ++++----- tests/unit/autojac/_transform/test_init.py | 8 +-- .../autojac/_transform/test_interactions.py | 20 +++---- tests/unit/autojac/_transform/test_jac.py | 24 ++++---- tests/unit/autojac/_transform/test_select.py | 6 +- tests/unit/autojac/_transform/test_stack.py | 8 +-- tests/unit/autojac/test_backward.py | 38 ++++++------ tests/unit/autojac/test_jac.py | 38 ++++++------ tests/unit/autojac/test_jac_to_grad.py | 14 ++--- tests/unit/autojac/test_mtl_backward.py | 60 +++++++++---------- tests/unit/autojac/test_utils.py | 28 ++++----- tests/unit/linalg/test_gramian.py | 18 +++--- tests/unit/test_deprecations.py | 2 +- tests/utils/forward_backwards.py | 2 +- 51 files changed, 329 insertions(+), 321 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 25e8281c..478c48cf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,16 +30,16 @@ def fix_randomness() -> None: torch.use_deterministic_algorithms(True) -def pytest_addoption(parser): +def pytest_addoption(parser) -> None: parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") -def pytest_configure(config): +def pytest_configure(config) -> None: config.addinivalue_line("markers", "slow: mark test as slow to run") config.addinivalue_line("markers", "xfail_if_cuda: mark test as xfail if running on cuda") -def pytest_collection_modifyitems(config, items): +def pytest_collection_modifyitems(config, items) -> None: skip_slow = mark.skip(reason="Slow test. Use --runslow to run it.") xfail_cuda = mark.xfail(reason=f"Test expected to fail on {DEVICE}") for item in items: diff --git a/tests/doc/test_aggregation.py b/tests/doc/test_aggregation.py index 75ef0ddc..a4219e4e 100644 --- a/tests/doc/test_aggregation.py +++ b/tests/doc/test_aggregation.py @@ -4,7 +4,7 @@ from torch.testing import assert_close -def test_aggregation_and_weighting(): +def test_aggregation_and_weighting() -> None: from torch import tensor from torchjd.aggregation import UPGrad, UPGradWeighting @@ -22,7 +22,7 @@ def test_aggregation_and_weighting(): assert_close(weights, tensor([1.1109, 0.7894]), rtol=0, atol=1e-4) -def test_generalized_weighting(): +def test_generalized_weighting() -> None: from torch import ones from torchjd.aggregation import Flattening, UPGradWeighting diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index 43651824..a0861e24 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -1,7 +1,7 @@ """This file contains tests for the usage examples related to autogram.""" -def test_engine(): +def test_engine() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index 2416210e..d08d2c2f 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -6,7 +6,7 @@ from utils.asserts import assert_jac_close -def test_backward(): +def test_backward() -> None: import torch from torchjd.autojac import backward @@ -21,7 +21,7 @@ def test_backward(): assert_jac_close(param, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) -def test_backward2(): +def test_backward2() -> None: import torch from torchjd.autojac import backward diff --git a/tests/doc/test_jac.py b/tests/doc/test_jac.py index e195c233..1a0b79a2 100644 --- a/tests/doc/test_jac.py +++ b/tests/doc/test_jac.py @@ -5,7 +5,7 @@ from torch.testing import assert_close -def test_jac(): +def test_jac() -> None: import torch from torchjd.autojac import jac @@ -20,7 +20,7 @@ def test_jac(): assert_close(jacobians[0], torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) -def test_jac_2(): +def test_jac_2() -> None: import torch from torchjd.autojac import jac @@ -44,7 +44,7 @@ def test_jac_2(): ) -def test_jac_3(): +def test_jac_3() -> None: import torch from torchjd.autojac import jac diff --git a/tests/doc/test_jac_to_grad.py b/tests/doc/test_jac_to_grad.py index 1f064a6c..04ca3ac2 100644 --- a/tests/doc/test_jac_to_grad.py +++ b/tests/doc/test_jac_to_grad.py @@ -6,7 +6,7 @@ from utils.asserts import assert_grad_close -def test_jac_to_grad(): +def test_jac_to_grad() -> None: import torch from torchjd.aggregation import UPGrad diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index f1df52d9..fcd6f392 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -9,7 +9,7 @@ from pytest import mark -def test_amp(): +def test_amp() -> None: import torch from torch.amp import GradScaler from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -51,7 +51,7 @@ def test_amp(): optimizer.zero_grad() -def test_basic_usage(): +def test_basic_usage() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -79,7 +79,7 @@ def test_basic_usage(): optimizer.zero_grad() -def test_iwmtl(): +def test_iwmtl() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -125,8 +125,8 @@ def test_iwmtl(): optimizer.zero_grad() -def test_iwrm(): - def test_autograd(): +def test_iwrm() -> None: + def test_autograd() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -147,7 +147,7 @@ def test_autograd(): optimizer.step() optimizer.zero_grad() - def test_autojac(): + def test_autojac() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -173,7 +173,7 @@ def test_autojac(): optimizer.step() optimizer.zero_grad() - def test_autogram(): + def test_autogram() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -212,7 +212,7 @@ def test_autogram(): "ignore::lightning.fabric.utilities.warnings.PossibleUserWarning", ) @no_type_check # Typing is annoying with Lightning, which would make the example too hard to read. -def test_lightning_integration(): +def test_lightning_integration() -> None: # Extra ---------------------------------------------------------------------------------------- import logging @@ -278,7 +278,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler: trainer.fit(model=model, train_dataloaders=train_loader) -def test_monitoring(): +def test_monitoring() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.nn.functional import cosine_similarity @@ -331,7 +331,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch. optimizer.zero_grad() -def test_mtl(): +def test_mtl() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -369,7 +369,7 @@ def test_mtl(): optimizer.zero_grad() -def test_partial_jd(): +def test_partial_jd() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -402,7 +402,7 @@ def test_partial_jd(): optimizer.zero_grad() -def test_rnn(): +def test_rnn() -> None: import torch from torch.nn import RNN from torch.optim import SGD diff --git a/tests/plots/static_plotter.py b/tests/plots/static_plotter.py index c26c0d6b..928d8eea 100644 --- a/tests/plots/static_plotter.py +++ b/tests/plots/static_plotter.py @@ -32,7 +32,7 @@ def main( mean=False, dual_proj=False, mgda=False, -): +) -> None: angle1 = 2.6 angle2 = 0.3277 norm1 = 0.9 diff --git a/tests/profiling/run_profiler.py b/tests/profiling/run_profiler.py index 7707cb01..ba793610 100644 --- a/tests/profiling/run_profiler.py +++ b/tests/profiling/run_profiler.py @@ -105,7 +105,7 @@ def _save_and_print_trace( def profile_autojac(factory: ModuleFactory, batch_size: int) -> None: - def forward_backward_fn(model, inputs, loss_fn): + def forward_backward_fn(model, inputs, loss_fn) -> None: aggregator = UPGrad() autojac_forward_backward(model, inputs, loss_fn, aggregator) @@ -113,7 +113,7 @@ def forward_backward_fn(model, inputs, loss_fn): def profile_autogram(factory: ModuleFactory, batch_size: int) -> None: - def forward_backward_fn(model, inputs, loss_fn): + def forward_backward_fn(model, inputs, loss_fn) -> None: engine = Engine(model, batch_dim=0) weighting = UPGradWeighting() autogram_forward_backward(model, inputs, loss_fn, engine, weighting) @@ -121,7 +121,7 @@ def forward_backward_fn(model, inputs, loss_fn): profile_method("autogram", forward_backward_fn, factory, batch_size) -def main(): +def main() -> None: for factory, batch_size in PARAMETRIZATIONS: profile_autojac(factory, batch_size) print("\n" + "=" * 80 + "\n") diff --git a/tests/profiling/speed_grad_vs_jac_vs_gram.py b/tests/profiling/speed_grad_vs_jac_vs_gram.py index d68d67aa..b4f9825d 100644 --- a/tests/profiling/speed_grad_vs_jac_vs_gram.py +++ b/tests/profiling/speed_grad_vs_jac_vs_gram.py @@ -41,13 +41,13 @@ ] -def main(): +def main() -> None: for factory, batch_size in PARAMETRIZATIONS: compare_autograd_autojac_and_autogram_speed(factory, batch_size) print("\n") -def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int): +def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int) -> None: model = factory() inputs, targets = make_inputs_and_targets(model, batch_size) loss_fn = make_mse_loss_fn(targets) @@ -57,47 +57,47 @@ def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_si print(f"\nTimes for forward + backward on {factory} with BS={batch_size}, A={A} on {DEVICE}.") - def fn_autograd(): + def fn_autograd() -> None: autograd_forward_backward(model, inputs, loss_fn) - def init_fn_autograd(): + def init_fn_autograd() -> None: torch.cuda.empty_cache() gc.collect() fn_autograd() - def fn_autograd_gramian(): + def fn_autograd_gramian() -> None: autograd_gramian_forward_backward(model, inputs, loss_fn, W) - def init_fn_autograd_gramian(): + def init_fn_autograd_gramian() -> None: torch.cuda.empty_cache() gc.collect() fn_autograd_gramian() - def fn_autojac(): + def fn_autojac() -> None: autojac_forward_backward(model, inputs, loss_fn, A) - def init_fn_autojac(): + def init_fn_autojac() -> None: torch.cuda.empty_cache() gc.collect() fn_autojac() - def fn_autogram(): + def fn_autogram() -> None: autogram_forward_backward(model, inputs, loss_fn, engine, W) - def init_fn_autogram(): + def init_fn_autogram() -> None: torch.cuda.empty_cache() gc.collect() fn_autogram() - def optionally_cuda_sync(): + def optionally_cuda_sync() -> None: if DEVICE.type == "cuda": torch.cuda.synchronize() - def pre_fn(): + def pre_fn() -> None: model.zero_grad() optionally_cuda_sync() - def post_fn(): + def post_fn() -> None: optionally_cuda_sync() n_runs = 10 @@ -121,7 +121,7 @@ def post_fn(): print_times("autogram", autogram_times) -def noop(): +def noop() -> None: pass diff --git a/tests/unit/aggregation/_asserts.py b/tests/unit/aggregation/_asserts.py index 8c119674..4b85bf09 100644 --- a/tests/unit/aggregation/_asserts.py +++ b/tests/unit/aggregation/_asserts.py @@ -101,7 +101,7 @@ def assert_strongly_stationary( assert norm > threshold -def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor): +def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor) -> None: """ Tests empirically that a given non-differentiable `Aggregator` correctly raises a NonDifferentiableError whenever we try to backward through it. diff --git a/tests/unit/aggregation/_utils/test_dual_cone.py b/tests/unit/aggregation/_utils/test_dual_cone.py index 3923d3f0..68a8a75d 100644 --- a/tests/unit/aggregation/_utils/test_dual_cone.py +++ b/tests/unit/aggregation/_utils/test_dual_cone.py @@ -8,7 +8,7 @@ @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)]) -def test_solution_weights(shape: tuple[int, int]): +def test_solution_weights(shape: tuple[int, int]) -> None: r""" Tests that `_project_weights` returns valid weights corresponding to the projection onto the dual cone of a matrix with the specified shape. @@ -54,7 +54,7 @@ def test_solution_weights(shape: tuple[int, int]): @mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)]) @mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4]) -def test_scale_invariant(shape: tuple[int, int], scaling: float): +def test_scale_invariant(shape: tuple[int, int], scaling: float) -> None: """ Tests that `_project_weights` is invariant under scaling. """ @@ -70,7 +70,7 @@ def test_scale_invariant(shape: tuple[int, int], scaling: float): @mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)]) -def test_tensorization_shape(shape: tuple[int, ...]): +def test_tensorization_shape(shape: tuple[int, ...]) -> None: """ Tests that applying `_project_weights` on a tensor is equivalent to applying it on the tensor reshaped as matrix and to reshape the result back to the original tensor's shape. @@ -88,7 +88,7 @@ def test_tensorization_shape(shape: tuple[int, ...]): assert_close(W_matrix.reshape(shape), W_tensor) -def test_project_weight_vector_failure(): +def test_project_weight_vector_failure() -> None: """Tests that `_project_weight_vector` raises an error when the input G has too large values.""" large_J = np.random.randn(10, 100) * 1e5 diff --git a/tests/unit/aggregation/_utils/test_pref_vector.py b/tests/unit/aggregation/_utils/test_pref_vector.py index 159582dd..0871726a 100644 --- a/tests/unit/aggregation/_utils/test_pref_vector.py +++ b/tests/unit/aggregation/_utils/test_pref_vector.py @@ -21,6 +21,8 @@ (ones_([1, 1, 1]), raises(ValueError)), ], ) -def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext): +def test_pref_vector_to_weighting_check( + pref_vector: Tensor | None, expectation: ExceptionContext +) -> None: with expectation: _ = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) diff --git a/tests/unit/aggregation/test_aggregator_bases.py b/tests/unit/aggregation/test_aggregator_bases.py index b08c37a8..80c9aeaa 100644 --- a/tests/unit/aggregation/test_aggregator_bases.py +++ b/tests/unit/aggregation/test_aggregator_bases.py @@ -18,6 +18,6 @@ ([1, 2, 3, 4], raises(ValueError)), ], ) -def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext): +def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext) -> None: with expectation: Aggregator._check_is_matrix(randn_(shape)) diff --git a/tests/unit/aggregation/test_aligned_mtl.py b/tests/unit/aggregation/test_aligned_mtl.py index 70847cbf..db89207b 100644 --- a/tests/unit/aggregation/test_aligned_mtl.py +++ b/tests/unit/aggregation/test_aligned_mtl.py @@ -19,16 +19,16 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: AlignedMTL, matrix: Tensor): +def test_expected_structure(aggregator: AlignedMTL, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor): +def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = AlignedMTL(pref_vector=None) assert repr(A) == "AlignedMTL(pref_vector=None, scale_mode='min')" assert str(A) == "AlignedMTL" @@ -38,7 +38,7 @@ def test_representations(): assert str(A) == "AlignedMTL([1., 2., 3.])" -def test_invalid_scale_mode(): +def test_invalid_scale_mode() -> None: aggregator = AlignedMTL(scale_mode="test") # type: ignore[arg-type] matrix = ones_(3, 4) with raises(ValueError, match=r"Invalid scale_mode=.*Expected"): diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py index e77cb729..c7d18b1f 100644 --- a/tests/unit/aggregation/test_cagrad.py +++ b/tests/unit/aggregation/test_cagrad.py @@ -23,17 +23,17 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: CAGrad, matrix: Tensor): +def test_expected_structure(aggregator: CAGrad, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: CAGrad, matrix: Tensor): +def test_non_differentiable(aggregator: CAGrad, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], non_conflicting_pairs_1 + non_conflicting_pairs_2) -def test_non_conflicting(aggregator: CAGrad, matrix: Tensor): +def test_non_conflicting(aggregator: CAGrad, matrix: Tensor) -> None: """Tests that CAGrad is non-conflicting when c >= 1 (it should not hold when c < 1).""" assert_non_conflicting(aggregator, matrix) @@ -48,12 +48,12 @@ def test_non_conflicting(aggregator: CAGrad, matrix: Tensor): (50.0, does_not_raise()), ], ) -def test_c_check(c: float, expectation: ExceptionContext): +def test_c_check(c: float, expectation: ExceptionContext) -> None: with expectation: _ = CAGrad(c=c) -def test_representations(): +def test_representations() -> None: A = CAGrad(c=0.5, norm_eps=0.0001) assert repr(A) == "CAGrad(c=0.5, norm_eps=0.0001)" assert str(A) == "CAGrad0.5" diff --git a/tests/unit/aggregation/test_config.py b/tests/unit/aggregation/test_config.py index 69cc4af1..2db2ea0f 100644 --- a/tests/unit/aggregation/test_config.py +++ b/tests/unit/aggregation/test_config.py @@ -20,26 +20,26 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: ConFIG, matrix: Tensor): +def test_expected_structure(aggregator: ConFIG, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: ConFIG, matrix: Tensor): +def test_permutation_invariant(aggregator: ConFIG, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_linear_under_scaling(aggregator: ConFIG, matrix: Tensor): +def test_linear_under_scaling(aggregator: ConFIG, matrix: Tensor) -> None: assert_linear_under_scaling(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: ConFIG, matrix: Tensor): +def test_non_differentiable(aggregator: ConFIG, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = ConFIG() assert repr(A) == "ConFIG(pref_vector=None)" assert str(A) == "ConFIG" diff --git a/tests/unit/aggregation/test_constant.py b/tests/unit/aggregation/test_constant.py index 984b7a1f..aa1332fc 100644 --- a/tests/unit/aggregation/test_constant.py +++ b/tests/unit/aggregation/test_constant.py @@ -28,17 +28,17 @@ def _make_aggregator(matrix: Tensor) -> Constant: @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: Constant, matrix: Tensor): +def test_expected_structure(aggregator: Constant, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_linear_under_scaling(aggregator: Constant, matrix: Tensor): +def test_linear_under_scaling(aggregator: Constant, matrix: Tensor) -> None: assert_linear_under_scaling(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) -def test_strongly_stationary(aggregator: Constant, matrix: Tensor): +def test_strongly_stationary(aggregator: Constant, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) @@ -57,7 +57,7 @@ def test_strongly_stationary(aggregator: Constant, matrix: Tensor): ([1, 1, 1, 1, 1], raises(ValueError)), ], ) -def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionContext): +def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionContext) -> None: weights = ones_(weights_shape) with expectation: _ = Constant(weights=weights) @@ -75,7 +75,9 @@ def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionCon ([5], 4, raises(ValueError)), ], ) -def test_matrix_shape_check(weights_shape: list[int], n_rows: int, expectation: ExceptionContext): +def test_matrix_shape_check( + weights_shape: list[int], n_rows: int, expectation: ExceptionContext +) -> None: matrix = ones_([n_rows, 5]) weights = ones_(weights_shape) aggregator = Constant(weights) @@ -84,7 +86,7 @@ def test_matrix_shape_check(weights_shape: list[int], n_rows: int, expectation: _ = aggregator(matrix) -def test_representations(): +def test_representations() -> None: A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu")) assert repr(A) == "Constant(weights=tensor([1., 2.]))" assert str(A) == "Constant([1., 2.])" diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 0f4407d2..5bd0e71a 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -21,31 +21,31 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: DualProj, matrix: Tensor): +def test_expected_structure(aggregator: DualProj, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_non_conflicting(aggregator: DualProj, matrix: Tensor): +def test_non_conflicting(aggregator: DualProj, matrix: Tensor) -> None: assert_non_conflicting(aggregator, matrix, atol=1e-04, rtol=1e-04) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: DualProj, matrix: Tensor): +def test_permutation_invariant(aggregator: DualProj, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=2e-07, rtol=2e-07) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) -def test_strongly_stationary(aggregator: DualProj, matrix: Tensor): +def test_strongly_stationary(aggregator: DualProj, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix, threshold=3e-03) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: DualProj, matrix: Tensor): +def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") assert ( repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" diff --git a/tests/unit/aggregation/test_graddrop.py b/tests/unit/aggregation/test_graddrop.py index 59e6e1ae..2868dca0 100644 --- a/tests/unit/aggregation/test_graddrop.py +++ b/tests/unit/aggregation/test_graddrop.py @@ -18,12 +18,12 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: GradDrop, matrix: Tensor): +def test_expected_structure(aggregator: GradDrop, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: GradDrop, matrix: Tensor): +def test_non_differentiable(aggregator: GradDrop, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) @@ -42,7 +42,7 @@ def test_non_differentiable(aggregator: GradDrop, matrix: Tensor): ([1, 1, 1, 1, 1], raises(ValueError)), ], ) -def test_leak_shape_check(leak_shape: list[int], expectation: ExceptionContext): +def test_leak_shape_check(leak_shape: list[int], expectation: ExceptionContext) -> None: leak = ones_(leak_shape) with expectation: _ = GradDrop(leak=leak) @@ -60,7 +60,9 @@ def test_leak_shape_check(leak_shape: list[int], expectation: ExceptionContext): ([5], 4, raises(ValueError)), ], ) -def test_matrix_shape_check(leak_shape: list[int], n_rows: int, expectation: ExceptionContext): +def test_matrix_shape_check( + leak_shape: list[int], n_rows: int, expectation: ExceptionContext +) -> None: matrix = ones_([n_rows, 5]) leak = ones_(leak_shape) aggregator = GradDrop(leak=leak) @@ -69,7 +71,7 @@ def test_matrix_shape_check(leak_shape: list[int], n_rows: int, expectation: Exc _ = aggregator(matrix) -def test_representations(): +def test_representations() -> None: A = GradDrop(leak=torch.tensor([0.0, 1.0], device="cpu")) assert re.match( r"GradDrop\(f=, leak=tensor\(\[0\., 1\.\]\)\)", diff --git a/tests/unit/aggregation/test_imtl_g.py b/tests/unit/aggregation/test_imtl_g.py index e9ba838c..03c41d5e 100644 --- a/tests/unit/aggregation/test_imtl_g.py +++ b/tests/unit/aggregation/test_imtl_g.py @@ -18,21 +18,21 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: IMTLG, matrix: Tensor): +def test_expected_structure(aggregator: IMTLG, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: IMTLG, matrix: Tensor): +def test_permutation_invariant(aggregator: IMTLG, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: IMTLG, matrix: Tensor): +def test_non_differentiable(aggregator: IMTLG, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) -def test_imtlg_zero(): +def test_imtlg_zero() -> None: """ Tests that IMTLG correctly returns the 0 vector in the special case where input matrix only consists of zeros. @@ -43,7 +43,7 @@ def test_imtlg_zero(): assert_close(A(J), zeros_(3)) -def test_representations(): +def test_representations() -> None: A = IMTLG() assert repr(A) == "IMTLG()" assert str(A) == "IMTLG" diff --git a/tests/unit/aggregation/test_krum.py b/tests/unit/aggregation/test_krum.py index 48fa4019..4097f2eb 100644 --- a/tests/unit/aggregation/test_krum.py +++ b/tests/unit/aggregation/test_krum.py @@ -15,7 +15,7 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: Krum, matrix: Tensor): +def test_expected_structure(aggregator: Krum, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @@ -29,7 +29,7 @@ def test_expected_structure(aggregator: Krum, matrix: Tensor): (5, does_not_raise()), ], ) -def test_n_byzantine_check(n_byzantine: int, expectation: ExceptionContext): +def test_n_byzantine_check(n_byzantine: int, expectation: ExceptionContext) -> None: with expectation: _ = Krum(n_byzantine=n_byzantine, n_selected=1) @@ -44,7 +44,7 @@ def test_n_byzantine_check(n_byzantine: int, expectation: ExceptionContext): (5, does_not_raise()), ], ) -def test_n_selected_check(n_selected: int, expectation: ExceptionContext): +def test_n_selected_check(n_selected: int, expectation: ExceptionContext) -> None: with expectation: _ = Krum(n_byzantine=1, n_selected=n_selected) @@ -66,7 +66,7 @@ def test_matrix_shape_check( n_selected: int, n_rows: int, expectation: ExceptionContext, -): +) -> None: aggregator = Krum(n_byzantine=n_byzantine, n_selected=n_selected) matrix = ones_([n_rows, 5]) @@ -74,7 +74,7 @@ def test_matrix_shape_check( _ = aggregator(matrix) -def test_representations(): +def test_representations() -> None: A = Krum(n_byzantine=1, n_selected=2) assert repr(A) == "Krum(n_byzantine=1, n_selected=2)" assert str(A) == "Krum1-2" diff --git a/tests/unit/aggregation/test_mean.py b/tests/unit/aggregation/test_mean.py index 4d3fbf3a..88c28e93 100644 --- a/tests/unit/aggregation/test_mean.py +++ b/tests/unit/aggregation/test_mean.py @@ -17,26 +17,26 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: Mean, matrix: Tensor): +def test_expected_structure(aggregator: Mean, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: Mean, matrix: Tensor): +def test_permutation_invariant(aggregator: Mean, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_linear_under_scaling(aggregator: Mean, matrix: Tensor): +def test_linear_under_scaling(aggregator: Mean, matrix: Tensor) -> None: assert_linear_under_scaling(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) -def test_strongly_stationary(aggregator: Mean, matrix: Tensor): +def test_strongly_stationary(aggregator: Mean, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = Mean() assert repr(A) == "Mean()" assert str(A) == "Mean" diff --git a/tests/unit/aggregation/test_mgda.py b/tests/unit/aggregation/test_mgda.py index 2d1fe068..5c925b8f 100644 --- a/tests/unit/aggregation/test_mgda.py +++ b/tests/unit/aggregation/test_mgda.py @@ -19,17 +19,17 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: MGDA, matrix: Tensor): +def test_expected_structure(aggregator: MGDA, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_non_conflicting(aggregator: MGDA, matrix: Tensor): +def test_non_conflicting(aggregator: MGDA, matrix: Tensor) -> None: assert_non_conflicting(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: MGDA, matrix: Tensor): +def test_permutation_invariant(aggregator: MGDA, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) @@ -43,7 +43,7 @@ def test_permutation_invariant(aggregator: MGDA, matrix: Tensor): (50, 100), ], ) -def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]): +def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]) -> None: matrix = randn_(shape) gramian = compute_gramian(matrix) @@ -66,7 +66,7 @@ def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]): assert_close(positive_mu.norm(), mu.norm(), atol=1e-02, rtol=0.0) -def test_representations(): +def test_representations() -> None: A = MGDA(epsilon=0.001, max_iters=100) assert repr(A) == "MGDA(epsilon=0.001, max_iters=100)" assert str(A) == "MGDA" diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index 44e15400..a1200d46 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -29,18 +29,18 @@ def _make_aggregator(matrix: Tensor) -> NashMTL: "ignore:You are solving a parameterized problem that is not DPP.", ) @mark.parametrize(["aggregator", "matrix"], standard_pairs) -def test_expected_structure(aggregator: NashMTL, matrix: Tensor): +def test_expected_structure(aggregator: NashMTL, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.filterwarnings("ignore:You are solving a parameterized problem that is not DPP.") @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: NashMTL, matrix: Tensor): +def test_non_differentiable(aggregator: NashMTL, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) @mark.filterwarnings("ignore: You are solving a parameterized problem that is not DPP.") -def test_nash_mtl_reset(): +def test_nash_mtl_reset() -> None: """ Tests that the reset method of NashMTL correctly resets its internal state, by verifying that the result is the same after reset as it is right after instantiation. @@ -59,7 +59,7 @@ def test_nash_mtl_reset(): assert_close(result, expected) -def test_representations(): +def test_representations() -> None: A = NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5) assert repr(A) == "NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5)" assert str(A) == "NashMTL" diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index 57a9120c..b776071d 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -17,12 +17,12 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: PCGrad, matrix: Tensor): +def test_expected_structure(aggregator: PCGrad, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: PCGrad, matrix: Tensor): +def test_non_differentiable(aggregator: PCGrad, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) @@ -41,7 +41,7 @@ def test_non_differentiable(aggregator: PCGrad, matrix: Tensor): (2, 11100), ], ) -def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]): +def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]) -> None: """ Tests that UPGradWeighting of a SumWeighting is equivalent to PCGradWeighting for matrices of 2 rows. @@ -64,7 +64,7 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]): assert_close(result, expected, atol=4e-04, rtol=0.0) -def test_representations(): +def test_representations() -> None: A = PCGrad() assert repr(A) == "PCGrad()" assert str(A) == "PCGrad" diff --git a/tests/unit/aggregation/test_random.py b/tests/unit/aggregation/test_random.py index d93929e2..77ab7f42 100644 --- a/tests/unit/aggregation/test_random.py +++ b/tests/unit/aggregation/test_random.py @@ -12,16 +12,16 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: Random, matrix: Tensor): +def test_expected_structure(aggregator: Random, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) -def test_strongly_stationary(aggregator: Random, matrix: Tensor): +def test_strongly_stationary(aggregator: Random, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = Random() assert repr(A) == "Random()" assert str(A) == "Random" diff --git a/tests/unit/aggregation/test_sum.py b/tests/unit/aggregation/test_sum.py index 99fe4e9f..386c507f 100644 --- a/tests/unit/aggregation/test_sum.py +++ b/tests/unit/aggregation/test_sum.py @@ -17,26 +17,26 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: Sum, matrix: Tensor): +def test_expected_structure(aggregator: Sum, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: Sum, matrix: Tensor): +def test_permutation_invariant(aggregator: Sum, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_linear_under_scaling(aggregator: Sum, matrix: Tensor): +def test_linear_under_scaling(aggregator: Sum, matrix: Tensor) -> None: assert_linear_under_scaling(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) -def test_strongly_stationary(aggregator: Sum, matrix: Tensor): +def test_strongly_stationary(aggregator: Sum, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = Sum() assert repr(A) == "Sum()" assert str(A) == "Sum" diff --git a/tests/unit/aggregation/test_trimmed_mean.py b/tests/unit/aggregation/test_trimmed_mean.py index cdeb9398..3a6ccb2b 100644 --- a/tests/unit/aggregation/test_trimmed_mean.py +++ b/tests/unit/aggregation/test_trimmed_mean.py @@ -15,12 +15,12 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: TrimmedMean, matrix: Tensor): +def test_expected_structure(aggregator: TrimmedMean, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: TrimmedMean, matrix: Tensor): +def test_permutation_invariant(aggregator: TrimmedMean, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) @@ -34,7 +34,7 @@ def test_permutation_invariant(aggregator: TrimmedMean, matrix: Tensor): (5, does_not_raise()), ], ) -def test_trim_number_check(trim_number: int, expectation: ExceptionContext): +def test_trim_number_check(trim_number: int, expectation: ExceptionContext) -> None: with expectation: _ = TrimmedMean(trim_number=trim_number) @@ -49,7 +49,7 @@ def test_trim_number_check(trim_number: int, expectation: ExceptionContext): (10, 5, raises(ValueError)), ], ) -def test_matrix_shape_check(n_rows: int, trim_number: int, expectation: ExceptionContext): +def test_matrix_shape_check(n_rows: int, trim_number: int, expectation: ExceptionContext) -> None: matrix = ones_([n_rows, 5]) aggregator = TrimmedMean(trim_number=trim_number) @@ -57,7 +57,7 @@ def test_matrix_shape_check(n_rows: int, trim_number: int, expectation: Exceptio _ = aggregator(matrix) -def test_representations(): +def test_representations() -> None: aggregator = TrimmedMean(trim_number=2) assert repr(aggregator) == "TrimmedMean(trim_number=2)" assert str(aggregator) == "TM2" diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 9fc480d2..1859b662 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -22,36 +22,36 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: UPGrad, matrix: Tensor): +def test_expected_structure(aggregator: UPGrad, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_non_conflicting(aggregator: UPGrad, matrix: Tensor): +def test_non_conflicting(aggregator: UPGrad, matrix: Tensor) -> None: assert_non_conflicting(aggregator, matrix, atol=4e-04, rtol=4e-04) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor): +def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=5e-07, rtol=5e-07) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor): +def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor) -> None: assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=6e-02, rtol=6e-02) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) -def test_strongly_stationary(aggregator: UPGrad, matrix: Tensor): +def test_strongly_stationary(aggregator: UPGrad, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix, threshold=5e-03) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: UPGrad, matrix: Tensor): +def test_non_differentiable(aggregator: UPGrad, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" assert str(A) == "UPGrad" diff --git a/tests/unit/aggregation/test_values.py b/tests/unit/aggregation/test_values.py index 719be0c3..860f313d 100644 --- a/tests/unit/aggregation/test_values.py +++ b/tests/unit/aggregation/test_values.py @@ -110,14 +110,14 @@ @mark.parametrize(["A", "J", "expected_output"], AGGREGATOR_PARAMETRIZATIONS) -def test_aggregator_output(A: Aggregator, J: Tensor, expected_output: Tensor): +def test_aggregator_output(A: Aggregator, J: Tensor, expected_output: Tensor) -> None: """Test that the output values of an aggregator are fixed (on cpu).""" assert_close(A(J), expected_output, rtol=0, atol=1e-4) @mark.parametrize(["W", "G", "expected_output"], WEIGHTING_PARAMETRIZATIONS) -def test_weighting_output(W: Weighting, G: Tensor, expected_output: Tensor): +def test_weighting_output(W: Weighting, G: Tensor, expected_output: Tensor) -> None: """Test that the output values of a weighting are fixed (on cpu).""" assert_close(W(G), expected_output, rtol=0, atol=1e-4) diff --git a/tests/unit/autogram/test_edge_registry.py b/tests/unit/autogram/test_edge_registry.py index 88d6da8c..56e5f720 100644 --- a/tests/unit/autogram/test_edge_registry.py +++ b/tests/unit/autogram/test_edge_registry.py @@ -4,7 +4,7 @@ from torchjd.autogram._edge_registry import EdgeRegistry -def test_all_edges_are_leaves1(): +def test_all_edges_are_leaves1() -> None: """Tests that get_leaf_edges works correctly when all edges are already leaves.""" a = randn_([3, 4], requires_grad=True) @@ -22,7 +22,7 @@ def test_all_edges_are_leaves1(): assert leaves == expected_leaves -def test_all_edges_are_leaves2(): +def test_all_edges_are_leaves2() -> None: """ Tests that get_leaf_edges works correctly when all edges are already leaves of the graph of edges leading to them, but are not leaves of the autograd graph. @@ -46,7 +46,7 @@ def test_all_edges_are_leaves2(): assert leaves == expected_leaves -def test_some_edges_are_not_leaves1(): +def test_some_edges_are_not_leaves1() -> None: """Tests that get_leaf_edges works correctly when some edges are leaves and some are not.""" a = randn_([3, 4], requires_grad=True) @@ -67,7 +67,7 @@ def test_some_edges_are_not_leaves1(): assert leaves == expected_leaves -def test_some_edges_are_not_leaves2(): +def test_some_edges_are_not_leaves2() -> None: """ Tests that get_leaf_edges works correctly when some edges are leaves and some are not. This time, not all tensors in the graph are registered so not all leavese in the graph have to be diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index b96824ec..76fdc41f 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -147,7 +147,7 @@ def _assert_gramian_is_equivalent_to_autograd( factory: ModuleFactory, batch_size: int, batch_dim: int | None, -): +) -> None: model_autograd, model_autogram = factory(), factory() engine = Engine(model_autogram, batch_dim=batch_dim) inputs, targets = make_inputs_and_targets(model_autograd, batch_size) @@ -191,7 +191,7 @@ def _get_losses_and_params_without_cross_terms( @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) @mark.parametrize("batch_dim", [0, None]) -def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int | None): +def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int | None) -> None: """Tests that the autograd and the autogram engines compute the same gramian.""" _assert_gramian_is_equivalent_to_autograd(factory, batch_size, batch_dim) @@ -213,7 +213,7 @@ def test_compute_gramian_with_weird_modules( factory: ModuleFactory, batch_size: int, batch_dim: int | None, -): +) -> None: """ Tests that compute_gramian works even with some problematic modules when batch_dim is None. It is expected to fail on those when the engine uses the batched optimization (when batch_dim=0). @@ -237,7 +237,7 @@ def test_compute_gramian_unsupported_architectures( factory: ModuleFactory, batch_size: int, batch_dim: int | None, -): +) -> None: """ Tests compute_gramian on some architectures that are known to be unsupported. It is expected to fail. @@ -275,7 +275,7 @@ def test_compute_gramian_various_output_shapes( batch_dim: int | None, movedim_source: list[int], movedim_destination: list[int], -): +) -> None: """ Tests that the autograd and the autogram engines compute the same gramian when the output can have various different shapes, and can be batched in any of its dimensions. @@ -312,7 +312,7 @@ def _non_empty_subsets(S: set) -> list[list]: @mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"})) @mark.parametrize("batch_dim", [0, None]) -def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int | None): +def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int | None) -> None: """ Tests that the autograd and the autogram engines compute the same gramian when only a subset of the model parameters is specified. @@ -340,7 +340,9 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) @mark.parametrize("batch_dim", [0, None]) -def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch_dim: int | None): +def test_iwrm_steps_with_autogram( + factory: ModuleFactory, batch_size: int, batch_dim: int | None +) -> None: """Tests that the autogram engine doesn't raise any error during several IWRM iterations.""" n_iter = 3 @@ -365,7 +367,7 @@ def test_autograd_while_modules_are_hooked( batch_size: int, use_engine: bool, batch_dim: int | None, -): +) -> None: """ Tests that the hooks added when constructing the engine do not interfere with a simple autograd call. @@ -402,7 +404,7 @@ def test_autograd_while_modules_are_hooked( (ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), 0), ], ) -def test_incompatible_modules(factory: ModuleFactory, batch_dim: int | None): +def test_incompatible_modules(factory: ModuleFactory, batch_dim: int | None) -> None: """Tests that the engine cannot be constructed with incompatible modules.""" model = factory() @@ -410,7 +412,7 @@ def test_incompatible_modules(factory: ModuleFactory, batch_dim: int | None): _ = Engine(model, batch_dim=batch_dim) -def test_compute_gramian_manual(): +def test_compute_gramian_manual() -> None: """ Tests that the Gramian computed by the `Engine` equals to a manual computation of the expected Gramian. @@ -454,7 +456,7 @@ def test_compute_gramian_manual(): [1], ], ) -def test_reshape_equivariance(shape: list[int]): +def test_reshape_equivariance(shape: list[int]) -> None: """ Test equivariance of `compute_gramian` under reshape operation. More precisely, if we reshape the `output` to some `shape`, then the result is the same as reshaping the Gramian to the @@ -492,7 +494,7 @@ def test_reshape_equivariance(shape: list[int]): ([1, 1, 1], [1, 0], [0, 1]), ], ) -def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]): +def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]) -> None: """ Test equivariance of `compute_gramian` under movedim operation. More precisely, if we movedim the `output` on some dimensions, then the result is the same as movedim on the Gramian with the @@ -532,7 +534,7 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: ([4, 3, 1], 2), ], ) -def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): +def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int) -> None: """ Tests that for a vector with some batched dimensions, the gramian is the same if we use the appropriate `batch_dim` or if we don't use any. @@ -558,7 +560,7 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) -def test_batched_non_batched_equivalence_2(factory: ModuleFactory, batch_size: int): +def test_batched_non_batched_equivalence_2(factory: ModuleFactory, batch_size: int) -> None: """ Same as test_batched_non_batched_equivalence but on real architectures, and thus only between batch_size=0 and batch_size=None. diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index 5f74df8d..7d4c2216 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -26,7 +26,7 @@ ([6, 7, 9], [6, 7, 9]), ], ) -def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]): +def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]) -> None: """Tests that reshape_gramian is such that compute_gramian is equivariant to a reshape.""" original_matrix = randn_([*original_shape, 2]) @@ -55,7 +55,7 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int] ([6, 7, 9], [6, 7, 9]), ], ) -def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]): +def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]) -> None: matrix = randn_([*original_shape, 2]) gramian = compute_gramian(matrix, 1) reshaped_gramian = reshape(gramian, target_shape) @@ -72,7 +72,7 @@ def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]): [6, 7, 9], ], ) -def test_flatten_yields_matrix(shape: list[int]): +def test_flatten_yields_matrix(shape: list[int]) -> None: matrix = randn_([*shape, 2]) gramian = compute_gramian(matrix, 1) flattened_gramian = flatten(gramian) @@ -89,7 +89,7 @@ def test_flatten_yields_matrix(shape: list[int]): [6, 7, 9], ], ) -def test_flatten_yields_psd(shape: list[int]): +def test_flatten_yields_psd(shape: list[int]) -> None: matrix = randn_([*shape, 2]) gramian = compute_gramian(matrix, 1) flattened_gramian = flatten(gramian) @@ -114,7 +114,7 @@ def test_flatten_yields_psd(shape: list[int]): ([2, 2, 3], [0, 2, 1], [1, 0, 2]), ], ) -def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]): +def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]) -> None: """Tests that movedim_gramian is such that compute_gramian is equivariant to a movedim.""" original_matrix = randn_([*shape, 2]) @@ -146,7 +146,7 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: ([2, 2, 3], [0, 2, 1], [1, 0, 2]), ], ) -def test_movedim_yields_psd(shape: list[int], source: list[int], destination: list[int]): +def test_movedim_yields_psd(shape: list[int], source: list[int], destination: list[int]) -> None: matrix = randn_([*shape, 2]) gramian = compute_gramian(matrix, 1) moveddim_gramian = movedim(gramian, source, destination) diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index 8c179a89..6e2a137e 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -6,7 +6,7 @@ from torchjd.autojac._transform import AccumulateGrad, AccumulateJac -def test_single_grad_accumulation(): +def test_single_grad_accumulation() -> None: """ Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run once. @@ -27,7 +27,7 @@ def test_single_grad_accumulation(): @mark.parametrize("iterations", [1, 2, 4, 10, 13]) -def test_multiple_grad_accumulations(iterations: int): +def test_multiple_grad_accumulations(iterations: int) -> None: """ Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run `iterations` times. @@ -47,7 +47,7 @@ def test_multiple_grad_accumulations(iterations: int): assert_grad_close(key, iterations * value) -def test_accumulate_grad_fails_when_no_requires_grad(): +def test_accumulate_grad_fails_when_no_requires_grad() -> None: """ Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a tensor that does not require grad. @@ -63,7 +63,7 @@ def test_accumulate_grad_fails_when_no_requires_grad(): accumulate(input) -def test_accumulate_grad_fails_when_no_leaf_and_no_retains_grad(): +def test_accumulate_grad_fails_when_no_leaf_and_no_retains_grad() -> None: """ Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a tensor that is not a leaf and that does not retain grad. @@ -79,7 +79,7 @@ def test_accumulate_grad_fails_when_no_leaf_and_no_retains_grad(): accumulate(input) -def test_accumulate_grad_check_keys(): +def test_accumulate_grad_check_keys() -> None: """Tests that the `check_keys` method works correctly for AccumulateGrad.""" key = tensor_([1.0], requires_grad=True) @@ -89,7 +89,7 @@ def test_accumulate_grad_check_keys(): assert output_keys == set() -def test_single_jac_accumulation(): +def test_single_jac_accumulation() -> None: """ Tests that the AccumulateJac transform correctly accumulates jacobians in .jac fields when run once. @@ -110,7 +110,7 @@ def test_single_jac_accumulation(): @mark.parametrize("iterations", [1, 2, 4, 10, 13]) -def test_multiple_jac_accumulations(iterations: int): +def test_multiple_jac_accumulations(iterations: int) -> None: """ Tests that the AccumulateJac transform correctly accumulates jacobians in .jac fields when run `iterations` times. @@ -131,7 +131,7 @@ def test_multiple_jac_accumulations(iterations: int): assert_jac_close(key, iterations * value) -def test_accumulate_jac_fails_when_no_requires_grad(): +def test_accumulate_jac_fails_when_no_requires_grad() -> None: """ Tests that the AccumulateJac transform raises an error when it tries to populate a .jac of a tensor that does not require grad. @@ -147,7 +147,7 @@ def test_accumulate_jac_fails_when_no_requires_grad(): accumulate(input) -def test_accumulate_jac_fails_when_no_leaf_and_no_retains_grad(): +def test_accumulate_jac_fails_when_no_leaf_and_no_retains_grad() -> None: """ Tests that the AccumulateJac transform raises an error when it tries to populate a .jac of a tensor that is not a leaf and that does not retain grad. @@ -163,7 +163,7 @@ def test_accumulate_jac_fails_when_no_leaf_and_no_retains_grad(): accumulate(input) -def test_accumulate_jac_fails_when_shape_mismatch(): +def test_accumulate_jac_fails_when_shape_mismatch() -> None: """ Tests that the AccumulateJac transform raises an error when the jacobian shape does not match the parameter shape (ignoring the first dimension). @@ -179,7 +179,7 @@ def test_accumulate_jac_fails_when_shape_mismatch(): accumulate(input) -def test_accumulate_jac_check_keys(): +def test_accumulate_jac_check_keys() -> None: """Tests that the `check_keys` method works correctly for AccumulateJac.""" key = tensor_([1.0], requires_grad=True) diff --git a/tests/unit/autojac/_transform/test_base.py b/tests/unit/autojac/_transform/test_base.py index e1e47ad2..a04dd866 100644 --- a/tests/unit/autojac/_transform/test_base.py +++ b/tests/unit/autojac/_transform/test_base.py @@ -29,7 +29,7 @@ def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: return self._output_keys -def test_composition_check_keys(): +def test_composition_check_keys() -> None: """ Tests that `check_keys` works correctly for a composition of transforms: the inner transform's `output_keys` has to satisfy the outer transform's requirements. @@ -52,7 +52,7 @@ def test_composition_check_keys(): (t2 << t1).check_keys({a1}) -def test_conjunct_check_keys_1(): +def test_conjunct_check_keys_1() -> None: """ Tests that `check_keys` works correctly for a conjunction of transforms: all transforms should successfully check their keys. @@ -75,7 +75,7 @@ def test_conjunct_check_keys_1(): (t1 | t2 | t3).check_keys({a1, a2}) -def test_conjunct_check_keys_2(): +def test_conjunct_check_keys_2() -> None: """ Tests that `check_keys` works correctly for a conjunction of transforms: their `output_keys` should be disjoint. @@ -98,7 +98,7 @@ def test_conjunct_check_keys_2(): (t1 | t2 | t3).check_keys(set()) -def test_empty_conjunction(): +def test_empty_conjunction() -> None: """ Tests that it is possible to take the conjunction of no transform. This should return an empty dictionary. @@ -109,7 +109,7 @@ def test_empty_conjunction(): assert len(conjunction({})) == 0 -def test_str(): +def test_str() -> None: """ Tests that the __str__ method works correctly even for transform involving compositions and conjunctions. diff --git a/tests/unit/autojac/_transform/test_diagonalize.py b/tests/unit/autojac/_transform/test_diagonalize.py index c1b30d31..b4e8e255 100644 --- a/tests/unit/autojac/_transform/test_diagonalize.py +++ b/tests/unit/autojac/_transform/test_diagonalize.py @@ -6,7 +6,7 @@ from torchjd.autojac._transform import Diagonalize, OrderedSet, RequirementError -def test_single_input(): +def test_single_input() -> None: """Tests that the Diagonalize transform works when given a single input.""" key = tensor_([1.0, 2.0, 3.0]) @@ -21,7 +21,7 @@ def test_single_input(): assert_tensor_dicts_are_close(output, expected_output) -def test_multiple_inputs(): +def test_multiple_inputs() -> None: """Tests that the Diagonalize transform works when given multiple inputs.""" key1 = tensor_([[1.0, 2.0], [4.0, 5.0]]) @@ -77,7 +77,7 @@ def test_multiple_inputs(): assert_tensor_dicts_are_close(output, expected_output) -def test_permute_order(): +def test_permute_order() -> None: """ Tests that the Diagonalize transform outputs a permuted mapping when its keys are permuted. """ @@ -98,7 +98,7 @@ def test_permute_order(): assert_tensor_dicts_are_close(output, expected_output) -def test_check_keys(): +def test_check_keys() -> None: """ Tests that the `check_keys` method works correctly. The input_keys must match the stored considered keys. diff --git a/tests/unit/autojac/_transform/test_grad.py b/tests/unit/autojac/_transform/test_grad.py index f834f73a..daa35bdf 100644 --- a/tests/unit/autojac/_transform/test_grad.py +++ b/tests/unit/autojac/_transform/test_grad.py @@ -6,7 +6,7 @@ from torchjd.autojac._transform import Grad, OrderedSet, RequirementError -def test_single_input(): +def test_single_input() -> None: """ Tests that the Grad transform works correctly for a very simple example of differentiation. Here, the function considered is: `y = a * x`. We want to compute the derivative of `y` with @@ -26,7 +26,7 @@ def test_single_input(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_empty_inputs_1(): +def test_empty_inputs_1() -> None: """ Tests that the Grad transform works correctly when the `inputs` parameter is an empty `Iterable`. @@ -43,7 +43,7 @@ def test_empty_inputs_1(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_empty_inputs_2(): +def test_empty_inputs_2() -> None: """ Tests that the Grad transform works correctly when the `inputs` parameter is an empty `Iterable`. @@ -62,7 +62,7 @@ def test_empty_inputs_2(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_empty_outputs(): +def test_empty_outputs() -> None: """ Tests that the Grad transform works correctly when the `outputs` parameter is an empty `Iterable`. @@ -80,7 +80,7 @@ def test_empty_outputs(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_retain_graph(): +def test_retain_graph() -> None: """Tests that the `Grad` transform behaves as expected with the `retain_graph` flag.""" x = tensor_(5.0) @@ -100,7 +100,7 @@ def test_retain_graph(): grad_discard_graph(input) -def test_single_input_two_levels(): +def test_single_input_two_levels() -> None: """ Tests that the Grad transform works correctly when composed with another Grad transform. Here, the function considered is: `z = a * x1 * x2`, which is computed in 2 parts: `y = a * x1` @@ -125,7 +125,7 @@ def test_single_input_two_levels(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_empty_inputs_two_levels(): +def test_empty_inputs_two_levels() -> None: """ Tests that the Grad transform works correctly when the `inputs` parameter is an empty `Iterable`, with 2 composed Grad transforms. @@ -148,7 +148,7 @@ def test_empty_inputs_two_levels(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_vector_output(): +def test_vector_output() -> None: """ Tests that the Grad transform works correctly when the `outputs` contains a single vector. The input (grad_outputs) is not the same for both values of the output, so that this test also @@ -168,7 +168,7 @@ def test_vector_output(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_multiple_outputs(): +def test_multiple_outputs() -> None: """ Tests that the Grad transform works correctly when the `outputs` contains 2 scalars. The input (grad_outputs) is not the same for both outputs, so that this test also checks that @@ -189,7 +189,7 @@ def test_multiple_outputs(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_multiple_tensor_outputs(): +def test_multiple_tensor_outputs() -> None: """ Tests that the Grad transform works correctly when the `outputs` contains several tensors of different shapes. The input (grad_outputs) is not the same for all values of the outputs, so @@ -216,7 +216,7 @@ def test_multiple_tensor_outputs(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_composition_of_grads_is_grad(): +def test_composition_of_grads_is_grad() -> None: """ Tests that the composition of 2 Grad transforms is equivalent to computing the Grad directly in a single transform. @@ -243,7 +243,7 @@ def test_composition_of_grads_is_grad(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_conjunction_of_grads_is_grad(): +def test_conjunction_of_grads_is_grad() -> None: """ Tests that the conjunction of 2 Grad transforms is equivalent to computing the Grad directly in a single transform. @@ -267,7 +267,7 @@ def test_conjunction_of_grads_is_grad(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_create_graph(): +def test_create_graph() -> None: """Tests that the Grad transform behaves correctly when `create_graph` is set to `True`.""" a = tensor_(2.0, requires_grad=True) @@ -281,7 +281,7 @@ def test_create_graph(): assert gradients[a].requires_grad -def test_check_keys(): +def test_check_keys() -> None: """ Tests that the `check_keys` method works correctly: the input_keys should match the stored outputs. diff --git a/tests/unit/autojac/_transform/test_init.py b/tests/unit/autojac/_transform/test_init.py index 38e4a29e..1833c145 100644 --- a/tests/unit/autojac/_transform/test_init.py +++ b/tests/unit/autojac/_transform/test_init.py @@ -5,7 +5,7 @@ from torchjd.autojac._transform import Init, RequirementError -def test_single_input(): +def test_single_input() -> None: """ Tests that when there is a single key to initialize, the Init transform creates a TensorDict whose value is a tensor full of ones, of the same shape as its key. @@ -22,7 +22,7 @@ def test_single_input(): assert_tensor_dicts_are_close(output, expected_output) -def test_multiple_inputs(): +def test_multiple_inputs() -> None: """ Tests that when there are several keys to initialize, the Init transform creates a TensorDict whose values are tensors full of ones, of the same shape as their corresponding keys. @@ -42,7 +42,7 @@ def test_multiple_inputs(): assert_tensor_dicts_are_close(output, expected) -def test_conjunction_of_inits_is_init(): +def test_conjunction_of_inits_is_init() -> None: """ Tests that the conjunction of 2 Init transforms is equivalent to a single Init transform with multiple keys. @@ -63,7 +63,7 @@ def test_conjunction_of_inits_is_init(): assert_tensor_dicts_are_close(output, expected_output) -def test_check_keys(): +def test_check_keys() -> None: """Tests that the `check_keys` method works correctly: the input_keys should be empty.""" key = tensor_([1.0]) diff --git a/tests/unit/autojac/_transform/test_interactions.py b/tests/unit/autojac/_transform/test_interactions.py index 470f5d6d..3f9a725a 100644 --- a/tests/unit/autojac/_transform/test_interactions.py +++ b/tests/unit/autojac/_transform/test_interactions.py @@ -18,7 +18,7 @@ ) -def test_jac_is_stack_of_grads(): +def test_jac_is_stack_of_grads() -> None: """ Tests that the Jac transform (composed with a Diagonalize) is equivalent to a Stack of Grad and Select transforms. @@ -52,7 +52,7 @@ def test_jac_is_stack_of_grads(): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_single_differentiation(): +def test_single_differentiation() -> None: """ Tests that we can perform a single scalar differentiation with the composition of a Grad and an Init transform. @@ -72,7 +72,7 @@ def test_single_differentiation(): assert_tensor_dicts_are_close(output, expected_output) -def test_multiple_differentiations(): +def test_multiple_differentiations() -> None: """ Tests that we can perform multiple scalar differentiations with the conjunction of multiple Grad transforms, composed with an Init transform. @@ -100,7 +100,7 @@ def test_multiple_differentiations(): assert_tensor_dicts_are_close(output, expected_output) -def test_str(): +def test_str() -> None: """Tests that the __str__ method works correctly even for a complex transform.""" init = Init(set()) diag = Diagonalize(OrderedSet([])) @@ -110,7 +110,7 @@ def test_str(): assert str(transform) == "Jac ∘ Diagonalize ∘ Init" -def test_simple_conjunction(): +def test_simple_conjunction() -> None: """ Tests that the Conjunction transform works correctly with a simple example involving several Select transforms, whose keys form a partition of the keys of the input tensor dict. @@ -133,7 +133,7 @@ def test_simple_conjunction(): assert_tensor_dicts_are_close(output, expected_output) -def test_conjunction_is_commutative(): +def test_conjunction_is_commutative() -> None: """ Tests that the Conjunction transform gives the same result no matter the order in which its transforms are given. @@ -154,7 +154,7 @@ def test_conjunction_is_commutative(): assert_tensor_dicts_are_close(output, expected_output) -def test_conjunction_is_associative(): +def test_conjunction_is_associative() -> None: """ Tests that the Conjunction transform gives the same result no matter how it is parenthesized. """ @@ -184,7 +184,7 @@ def test_conjunction_is_associative(): assert_tensor_dicts_are_close(output, expected_output) -def test_conjunction_accumulate_select(): +def test_conjunction_accumulate_select() -> None: """ Tests that it is possible to conjunct an AccumulateGrad and a Select in this order. It is not trivial since the type of the TensorDict returned by the first transform @@ -206,7 +206,7 @@ def test_conjunction_accumulate_select(): assert_tensor_dicts_are_close(output, expected_output) -def test_equivalence_jac_grads(): +def test_equivalence_jac_grads() -> None: """ Tests that differentiation in parallel using `_jac` is equivalent to sequential differentiation using several calls to `_grad` and stacking the resulting gradients. @@ -248,7 +248,7 @@ def test_equivalence_jac_grads(): assert_close(jac_c, torch.stack([grad_1_c, grad_2_c])) -def test_stack_check_keys(): +def test_stack_check_keys() -> None: """ Tests that the `check_keys` method works correctly for a stack of transforms: all of them should successfully check their keys. diff --git a/tests/unit/autojac/_transform/test_jac.py b/tests/unit/autojac/_transform/test_jac.py index e1efecf5..7e21f5fc 100644 --- a/tests/unit/autojac/_transform/test_jac.py +++ b/tests/unit/autojac/_transform/test_jac.py @@ -7,7 +7,7 @@ @mark.parametrize("chunk_size", [1, 3, None]) -def test_single_input(chunk_size: int | None): +def test_single_input(chunk_size: int | None) -> None: """ Tests that the Jac transform works correctly for an example of multiple differentiation. Here, the function considered is: `y = [a1 * x, a2 * x]`. We want to compute the jacobians of `y` with @@ -32,7 +32,7 @@ def test_single_input(chunk_size: int | None): @mark.parametrize("chunk_size", [1, 3, None]) -def test_empty_inputs_1(chunk_size: int | None): +def test_empty_inputs_1(chunk_size: int | None) -> None: """ Tests that the Jac transform works correctly when the `inputs` parameter is an empty `Iterable`. """ @@ -51,7 +51,7 @@ def test_empty_inputs_1(chunk_size: int | None): @mark.parametrize("chunk_size", [1, 3, None]) -def test_empty_inputs_2(chunk_size: int | None): +def test_empty_inputs_2(chunk_size: int | None) -> None: """ Tests that the Jac transform works correctly when the `inputs` parameter is an empty `Iterable`. """ @@ -73,7 +73,7 @@ def test_empty_inputs_2(chunk_size: int | None): @mark.parametrize("chunk_size", [1, 3, None]) -def test_empty_outputs(chunk_size: int | None): +def test_empty_outputs(chunk_size: int | None) -> None: """ Tests that the Jac transform works correctly when the `outputs` parameter is an empty `Iterable`. @@ -94,7 +94,7 @@ def test_empty_outputs(chunk_size: int | None): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_retain_graph(): +def test_retain_graph() -> None: """Tests that the `Jac` transform behaves as expected with the `retain_graph` flag.""" x = tensor_(5.0) @@ -127,7 +127,7 @@ def test_retain_graph(): jac_discard_graph(input) -def test_two_levels(): +def test_two_levels() -> None: """ Tests that the Jac transform works correctly for an example of chained differentiation. Here, the function considered is: `z = a * x1 * x2`, which is computed in 2 parts: `y = a * x1` and @@ -167,7 +167,7 @@ def test_two_levels(): @mark.parametrize("chunk_size", [1, 3, None]) -def test_multiple_outputs_1(chunk_size: int | None): +def test_multiple_outputs_1(chunk_size: int | None) -> None: """ Tests that the Jac transform works correctly when the `outputs` contains 3 vectors. The input (jac_outputs) is not the same for all outputs, so that this test also checks that the @@ -201,7 +201,7 @@ def test_multiple_outputs_1(chunk_size: int | None): @mark.parametrize("chunk_size", [1, 3, None]) -def test_multiple_outputs_2(chunk_size: int | None): +def test_multiple_outputs_2(chunk_size: int | None) -> None: """ Same as test_multiple_outputs_1 but with different jac_outputs, so the returned jacobians are of different shapes. @@ -232,7 +232,7 @@ def test_multiple_outputs_2(chunk_size: int | None): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_composition_of_jacs_is_jac(): +def test_composition_of_jacs_is_jac() -> None: """ Tests that the composition of 2 Jac transforms is equivalent to computing the Jac directly in a single transform. @@ -268,7 +268,7 @@ def test_composition_of_jacs_is_jac(): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_conjunction_of_jacs_is_jac(): +def test_conjunction_of_jacs_is_jac() -> None: """ Tests that the conjunction of 2 Jac transforms is equivalent to computing the Jac directly in a single transform. @@ -294,7 +294,7 @@ def test_conjunction_of_jacs_is_jac(): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_create_graph(): +def test_create_graph() -> None: """Tests that the Jac transform behaves correctly when `create_graph` is set to `True`.""" x = tensor_(5.0) @@ -318,7 +318,7 @@ def test_create_graph(): assert jacobians[a2].requires_grad -def test_check_keys(): +def test_check_keys() -> None: """ Tests that the `check_keys` method works correctly: the input_keys should match the stored outputs. diff --git a/tests/unit/autojac/_transform/test_select.py b/tests/unit/autojac/_transform/test_select.py index 041eefc9..897ad817 100644 --- a/tests/unit/autojac/_transform/test_select.py +++ b/tests/unit/autojac/_transform/test_select.py @@ -6,7 +6,7 @@ from torchjd.autojac._transform import RequirementError, Select -def test_partition(): +def test_partition() -> None: """ Tests that the Select transform works correctly by applying 2 different Selects to a TensorDict, whose keys form a partition of the keys of the TensorDict. @@ -34,7 +34,7 @@ def test_partition(): assert_tensor_dicts_are_close(output2, expected_output2) -def test_conjunction_of_selects_is_select(): +def test_conjunction_of_selects_is_select() -> None: """ Tests that the conjunction of 2 Select transforms is equivalent to directly using a Select with the union of the keys of the 2 Selects. @@ -56,7 +56,7 @@ def test_conjunction_of_selects_is_select(): assert_tensor_dicts_are_close(output, expected_output) -def test_check_keys(): +def test_check_keys() -> None: """ Tests that the `check_keys` method works correctly: the set of keys to select should be a subset of the set of required_keys. diff --git a/tests/unit/autojac/_transform/test_stack.py b/tests/unit/autojac/_transform/test_stack.py index ef644fb5..28515fa9 100644 --- a/tests/unit/autojac/_transform/test_stack.py +++ b/tests/unit/autojac/_transform/test_stack.py @@ -22,7 +22,7 @@ def check_keys(self, _input_keys: set[Tensor], /) -> set[Tensor]: return self.keys -def test_single_key(): +def test_single_key() -> None: """ Tests that the Stack transform correctly stacks gradients into a jacobian, in a very simple example with 2 transforms sharing the same key. @@ -40,7 +40,7 @@ def test_single_key(): assert_tensor_dicts_are_close(output, expected_output) -def test_disjoint_key_sets(): +def test_disjoint_key_sets() -> None: """ Tests that the Stack transform correctly stacks gradients into a jacobian, in an example where the output key sets of all of its transforms are disjoint. The missing values should be replaced @@ -64,7 +64,7 @@ def test_disjoint_key_sets(): assert_tensor_dicts_are_close(output, expected_output) -def test_overlapping_key_sets(): +def test_overlapping_key_sets() -> None: """ Tests that the Stack transform correctly stacks gradients into a jacobian, in an example where the output key sets all of its transforms are overlapping (non-empty intersection, but not @@ -90,7 +90,7 @@ def test_overlapping_key_sets(): assert_tensor_dicts_are_close(output, expected_output) -def test_empty(): +def test_empty() -> None: """Tests that the Stack transform correctly handles an empty list of transforms.""" stack = Stack([]) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index a0398c42..edd2f197 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -9,7 +9,7 @@ @mark.parametrize("default_jac_tensors", [True, False]) -def test_check_create_transform(default_jac_tensors: bool): +def test_check_create_transform(default_jac_tensors: bool) -> None: """Tests that _create_transform creates a valid Transform.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -37,7 +37,7 @@ def test_check_create_transform(default_jac_tensors: bool): assert output_keys == set() -def test_jac_is_populated(): +def test_jac_is_populated() -> None: """Tests that backward correctly fills the .jac field.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -59,7 +59,7 @@ def test_value_is_correct( shape: tuple[int, int], manually_specify_inputs: bool, chunk_size: int | None, -): +) -> None: """ Tests that the .jac value filled by backward is correct in a simple example of matrix-vector product. @@ -81,7 +81,7 @@ def test_value_is_correct( @mark.parametrize("rows", [1, 2, 5]) -def test_jac_tensors_value_is_correct(rows: int): +def test_jac_tensors_value_is_correct(rows: int) -> None: """ Tests that backward correctly computes the product of jac_tensors and the Jacobian. result = jac_tensors @ Jacobian(tensors, inputs). @@ -107,7 +107,7 @@ def test_jac_tensors_value_is_correct(rows: int): @mark.parametrize("rows", [1, 3]) -def test_jac_tensors_multiple_components(rows: int): +def test_jac_tensors_multiple_components(rows: int) -> None: """ Tests that jac_tensors works correctly when tensors is a list of multiple tensors. The jac_tensors must match the structure of tensors. @@ -132,7 +132,7 @@ def test_jac_tensors_multiple_components(rows: int): assert_jac_close(input, expected) -def test_jac_tensors_length_mismatch(): +def test_jac_tensors_length_mismatch() -> None: """Tests that backward raises a ValueError early if len(jac_tensors) != len(tensors).""" x = tensor_([1.0, 2.0], requires_grad=True) y1 = x * 2 @@ -147,7 +147,7 @@ def test_jac_tensors_length_mismatch(): backward([y1, y2], jac_tensors=[J1], inputs=[x]) -def test_jac_tensors_shape_mismatch(): +def test_jac_tensors_shape_mismatch() -> None: """ Tests that backward raises a ValueError early if the shape of a tensor in jac_tensors is incompatible with the corresponding tensor. @@ -171,7 +171,7 @@ def test_jac_tensors_shape_mismatch(): (1, 2), ], ) -def test_jac_tensors_inconsistent_first_dimension(rows_y1: int, rows_y2: int): +def test_jac_tensors_inconsistent_first_dimension(rows_y1: int, rows_y2: int) -> None: """ Tests that backward raises a ValueError early when the provided jac_tensors have inconsistent first dimensions. @@ -190,7 +190,7 @@ def test_jac_tensors_inconsistent_first_dimension(rows_y1: int, rows_y2: int): backward([y1, y2], jac_tensors=[j1, j2], inputs=[x]) -def test_empty_inputs(): +def test_empty_inputs() -> None: """Tests that backward does not fill the .jac values if no input is specified.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -205,7 +205,7 @@ def test_empty_inputs(): assert_has_no_jac(a) -def test_partial_inputs(): +def test_partial_inputs() -> None: """ Tests that backward fills the right .jac values when only a subset of the actual inputs are specified as inputs. @@ -223,7 +223,7 @@ def test_partial_inputs(): assert_has_no_jac(a2) -def test_empty_tensors_fails(): +def test_empty_tensors_fails() -> None: """Tests that backward raises an error when called with an empty list of tensors.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -233,7 +233,7 @@ def test_empty_tensors_fails(): backward([], inputs=[a1, a2]) -def test_multiple_tensors(): +def test_multiple_tensors() -> None: """ Tests that giving multiple tensors to backward is equivalent to giving a single tensor containing all the values of the original tensors. @@ -268,7 +268,7 @@ def test_multiple_tensors(): @mark.parametrize("chunk_size", [None, 1, 2, 4]) -def test_various_valid_chunk_sizes(chunk_size): +def test_various_valid_chunk_sizes(chunk_size) -> None: """Tests that backward works for various valid values of parallel_chunk_size.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -284,7 +284,7 @@ def test_various_valid_chunk_sizes(chunk_size): @mark.parametrize("chunk_size", [0, -1]) -def test_non_positive_chunk_size_fails(chunk_size: int): +def test_non_positive_chunk_size_fails(chunk_size: int) -> None: """Tests that backward raises an error when using invalid chunk sizes.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -297,7 +297,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int): backward([y1, y2], parallel_chunk_size=chunk_size) -def test_input_retaining_grad_fails(): +def test_input_retaining_grad_fails() -> None: """ Tests that backward raises an error when some input in the computation graph of the ``tensors`` parameter retains grad and vmap has to be used. @@ -317,7 +317,7 @@ def test_input_retaining_grad_fails(): _ = -b.grad # type: ignore[unsupported-operator] -def test_non_input_retaining_grad_fails(): +def test_non_input_retaining_grad_fails() -> None: """ Tests that backward fails to fill a valid `.grad` when some tensor in the computation graph of the ``tensors`` parameter retains grad and vmap has to be used. @@ -337,7 +337,7 @@ def test_non_input_retaining_grad_fails(): @mark.parametrize("chunk_size", [1, 3, None]) -def test_tensor_used_multiple_times(chunk_size: int | None): +def test_tensor_used_multiple_times(chunk_size: int | None) -> None: """ Tests that backward works correctly when one of the inputs is used multiple times. In this setup, the autograd graph is still acyclic, but the graph of tensors used becomes cyclic. @@ -356,7 +356,7 @@ def test_tensor_used_multiple_times(chunk_size: int | None): assert_jac_close(a, J) -def test_repeated_tensors(): +def test_repeated_tensors() -> None: """ Tests that backward does not allow repeating tensors. @@ -375,7 +375,7 @@ def test_repeated_tensors(): backward([y1, y1, y2]) -def test_repeated_inputs(): +def test_repeated_inputs() -> None: """ Tests that backward correctly works when some inputs are repeated. In this case, since torch.autograd.backward ignores the repetition of the inputs, it is natural for autojac to diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 3ee6561f..c880eacc 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -9,7 +9,7 @@ @mark.parametrize("default_jac_outputs", [True, False]) -def test_check_create_transform(default_jac_outputs: bool): +def test_check_create_transform(default_jac_outputs: bool) -> None: """Tests that _create_transform creates a valid Transform.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -37,7 +37,7 @@ def test_check_create_transform(default_jac_outputs: bool): assert output_keys == {a1, a2} -def test_jac(): +def test_jac() -> None: """Tests that jac works.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -65,7 +65,7 @@ def test_value_is_correct( chunk_size: int | None, outputs_is_list: bool, inputs_is_list: bool, -): +) -> None: """ Tests that the jacobians returned by jac are correct in a simple example of matrix-vector product. @@ -85,7 +85,7 @@ def test_value_is_correct( @mark.parametrize("rows", [1, 2, 5]) -def test_jac_outputs_value_is_correct(rows: int): +def test_jac_outputs_value_is_correct(rows: int) -> None: """ Tests that jac correctly computes the product of jac_outputs and the Jacobian. result = jac_outputs @ Jacobian(outputs, inputs). @@ -111,7 +111,7 @@ def test_jac_outputs_value_is_correct(rows: int): @mark.parametrize("rows", [1, 3]) -def test_jac_outputs_multiple_components(rows: int): +def test_jac_outputs_multiple_components(rows: int) -> None: """ Tests that jac_outputs works correctly when outputs is a list of multiple tensors. The jac_outputs must match the structure of outputs. @@ -136,7 +136,7 @@ def test_jac_outputs_multiple_components(rows: int): assert_close(jacobians[0], expected) -def test_jac_outputs_length_mismatch(): +def test_jac_outputs_length_mismatch() -> None: """Tests that jac raises a ValueError early if len(jac_outputs) != len(outputs).""" x = tensor_([1.0, 2.0], requires_grad=True) y1 = x * 2 @@ -151,7 +151,7 @@ def test_jac_outputs_length_mismatch(): jac([y1, y2], x, jac_outputs=[J1]) -def test_jac_outputs_shape_mismatch(): +def test_jac_outputs_shape_mismatch() -> None: """ Tests that jac raises a ValueError early if the shape of a tensor in jac_outputs is incompatible with the corresponding output tensor. @@ -175,7 +175,7 @@ def test_jac_outputs_shape_mismatch(): (1, 2), ], ) -def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int): +def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int) -> None: """ Tests that jac raises a ValueError early when the provided jac_outputs have inconsistent first dimensions. @@ -194,7 +194,7 @@ def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int): jac([y1, y2], x, jac_outputs=[j1, j2]) -def test_empty_inputs(): +def test_empty_inputs() -> None: """Tests that jac does not return any jacobian no input is specified.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -207,7 +207,7 @@ def test_empty_inputs(): assert len(jacobians) == 0 -def test_partial_inputs(): +def test_partial_inputs() -> None: """ Tests that jac returns the right jacobians when only a subset of the actual inputs are specified as inputs. @@ -223,7 +223,7 @@ def test_partial_inputs(): assert len(jacobians) == 1 -def test_empty_tensors_fails(): +def test_empty_tensors_fails() -> None: """Tests that jac raises an error when called with an empty list of tensors.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -233,7 +233,7 @@ def test_empty_tensors_fails(): jac([], inputs=[a1, a2]) -def test_multiple_tensors(): +def test_multiple_tensors() -> None: """ Tests that giving multiple tensors to jac is equivalent to giving a single tensor containing all the values of the original tensors. @@ -268,7 +268,7 @@ def test_multiple_tensors(): @mark.parametrize("chunk_size", [None, 1, 2, 4]) -def test_various_valid_chunk_sizes(chunk_size): +def test_various_valid_chunk_sizes(chunk_size) -> None: """Tests that jac works for various valid values of parallel_chunk_size.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -282,7 +282,7 @@ def test_various_valid_chunk_sizes(chunk_size): @mark.parametrize("chunk_size", [0, -1]) -def test_non_positive_chunk_size_fails(chunk_size: int): +def test_non_positive_chunk_size_fails(chunk_size: int) -> None: """Tests that jac raises an error when using invalid chunk sizes.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -295,7 +295,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int): jac([y1, y2], [a1, a2], parallel_chunk_size=chunk_size) -def test_input_retaining_grad_fails(): +def test_input_retaining_grad_fails() -> None: """ Tests that jac raises an error when some input in the computation graph of the ``tensors`` parameter retains grad and vmap has to be used. @@ -315,7 +315,7 @@ def test_input_retaining_grad_fails(): _ = -b.grad # type: ignore[unsupported-operator] -def test_non_input_retaining_grad_fails(): +def test_non_input_retaining_grad_fails() -> None: """ Tests that jac fails to fill a valid `.grad` when some tensor in the computation graph of the ``tensors`` parameter retains grad and vmap has to be used. @@ -335,7 +335,7 @@ def test_non_input_retaining_grad_fails(): @mark.parametrize("chunk_size", [1, 3, None]) -def test_tensor_used_multiple_times(chunk_size: int | None): +def test_tensor_used_multiple_times(chunk_size: int | None) -> None: """ Tests that jac works correctly when one of the inputs is used multiple times. In this setup, the autograd graph is still acyclic, but the graph of tensors used becomes cyclic. @@ -355,7 +355,7 @@ def test_tensor_used_multiple_times(chunk_size: int | None): assert_close(jacobians[0], J) -def test_repeated_tensors(): +def test_repeated_tensors() -> None: """ Tests that jac does not allow repeating tensors. @@ -374,7 +374,7 @@ def test_repeated_tensors(): jac([y1, y1, y2], [a1, a2]) -def test_repeated_inputs(): +def test_repeated_inputs() -> None: """ Tests that jac correctly works when some inputs are repeated. In this case, since torch.autograd.grad repeats the output gradients, it is natural for autojac to also repeat the diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index a3f83097..2e9dca4b 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -7,7 +7,7 @@ @mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()]) -def test_various_aggregators(aggregator: Aggregator): +def test_various_aggregators(aggregator: Aggregator) -> None: """Tests that jac_to_grad works for various aggregators.""" t1 = tensor_(1.0, requires_grad=True) @@ -25,7 +25,7 @@ def test_various_aggregators(aggregator: Aggregator): assert_grad_close(t2, g2) -def test_single_tensor(): +def test_single_tensor() -> None: """Tests that jac_to_grad works when a single tensor is provided.""" aggregator = UPGrad() @@ -39,7 +39,7 @@ def test_single_tensor(): assert_grad_close(t, g) -def test_no_jac_field(): +def test_no_jac_field() -> None: """Tests that jac_to_grad fails when a tensor does not have a jac field.""" aggregator = UPGrad() @@ -52,7 +52,7 @@ def test_no_jac_field(): jac_to_grad([t1, t2], aggregator) -def test_no_requires_grad(): +def test_no_requires_grad() -> None: """Tests that jac_to_grad fails when a tensor does not require grad.""" aggregator = UPGrad() @@ -66,7 +66,7 @@ def test_no_requires_grad(): jac_to_grad([t1, t2], aggregator) -def test_row_mismatch(): +def test_row_mismatch() -> None: """Tests that jac_to_grad fails when the number of rows of the .jac is not constant.""" aggregator = UPGrad() @@ -79,14 +79,14 @@ def test_row_mismatch(): jac_to_grad([t1, t2], aggregator) -def test_no_tensors(): +def test_no_tensors() -> None: """Tests that jac_to_grad correctly does nothing when an empty list of tensors is provided.""" jac_to_grad([], aggregator=UPGrad()) @mark.parametrize("retain_jac", [True, False]) -def test_jacs_are_freed(retain_jac: bool): +def test_jacs_are_freed(retain_jac: bool) -> None: """Tests that jac_to_grad frees the jac fields if an only if retain_jac is False.""" aggregator = UPGrad() diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 9f15e6fd..6de8d7aa 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -18,7 +18,7 @@ @mark.parametrize("default_grad_tensors", [True, False]) -def test_check_create_transform(default_grad_tensors: bool): +def test_check_create_transform(default_grad_tensors: bool) -> None: """Tests that _create_transform creates a valid Transform.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -48,7 +48,7 @@ def test_check_create_transform(default_grad_tensors: bool): assert output_keys == set() -def test_shape_is_correct(): +def test_shape_is_correct() -> None: """Tests that mtl_backward works correctly.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -76,7 +76,7 @@ def test_value_is_correct( manually_specify_shared_params: bool, manually_specify_tasks_params: bool, chunk_size: int | None, -): +) -> None: """ Tests that the .jac value filled by mtl_backward is correct in a simple example of matrix-vector product for three tasks whose loss are given by a simple inner product of the @@ -116,7 +116,7 @@ def test_value_is_correct( assert_jac_close(p0, expected_jacobian) -def test_empty_tasks_fails(): +def test_empty_tasks_fails() -> None: """Tests that mtl_backward raises an error when called with an empty list of tasks.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -128,7 +128,7 @@ def test_empty_tasks_fails(): mtl_backward([], features=[f1, f2]) -def test_single_task(): +def test_single_task() -> None: """Tests that mtl_backward works correctly with a single task.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -144,7 +144,7 @@ def test_single_task(): assert_has_grad(p1) -def test_incoherent_task_number_fails(): +def test_incoherent_task_number_fails() -> None: """ Tests that mtl_backward raises an error when called with the number of tasks losses different from the number of tasks parameters. @@ -175,7 +175,7 @@ def test_incoherent_task_number_fails(): ) -def test_empty_params(): +def test_empty_params() -> None: """Tests that mtl_backward does not fill the .jac/.grad values if no parameter is specified.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -199,7 +199,7 @@ def test_empty_params(): assert_has_no_grad(p) -def test_multiple_params_per_task(): +def test_multiple_params_per_task() -> None: """Tests that mtl_backward works correctly when the tasks each have several parameters.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -234,7 +234,7 @@ def test_multiple_params_per_task(): [(5, 4, 3, 2), (5, 4, 3, 2)], ], ) -def test_various_shared_params(shared_params_shapes: list[tuple[int]]): +def test_various_shared_params(shared_params_shapes: list[tuple[int]]) -> None: """Tests that mtl_backward works correctly with various kinds of shared_params.""" shared_params = [rand_(shape, requires_grad=True) for shape in shared_params_shapes] @@ -258,7 +258,7 @@ def test_various_shared_params(shared_params_shapes: list[tuple[int]]): assert_has_grad(p) -def test_partial_params(): +def test_partial_params() -> None: """ Tests that mtl_backward fills the right .jac/.grad values when only a subset of the parameters are specified as inputs. @@ -285,7 +285,7 @@ def test_partial_params(): assert_has_no_grad(p2) -def test_empty_features_fails(): +def test_empty_features_fails() -> None: """Tests that mtl_backward expectedly raises an error when no there is no feature.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -310,7 +310,7 @@ def test_empty_features_fails(): (5, 4, 3, 2), ], ) -def test_various_single_features(shape: tuple[int, ...]): +def test_various_single_features(shape: tuple[int, ...]) -> None: """Tests that mtl_backward works correctly with various kinds of feature tensors.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -342,7 +342,7 @@ def test_various_single_features(shape: tuple[int, ...]): [(5, 4, 3, 2), (5, 4, 3, 2)], ], ) -def test_various_feature_lists(shapes: list[tuple[int]]): +def test_various_feature_lists(shapes: list[tuple[int]]) -> None: """Tests that mtl_backward works correctly with various kinds of feature lists.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -361,7 +361,7 @@ def test_various_feature_lists(shapes: list[tuple[int]]): assert_has_grad(p) -def test_non_scalar_loss_fails(): +def test_non_scalar_loss_fails() -> None: """Tests that mtl_backward raises an error when used with a non-scalar loss.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -378,7 +378,7 @@ def test_non_scalar_loss_fails(): @mark.parametrize("chunk_size", [None, 1, 2, 4]) -def test_various_valid_chunk_sizes(chunk_size): +def test_various_valid_chunk_sizes(chunk_size) -> None: """Tests that mtl_backward works for various valid values of parallel_chunk_size.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -402,7 +402,7 @@ def test_various_valid_chunk_sizes(chunk_size): @mark.parametrize("chunk_size", [0, -1]) -def test_non_positive_chunk_size_fails(chunk_size: int): +def test_non_positive_chunk_size_fails(chunk_size: int) -> None: """Tests that mtl_backward raises an error when using invalid chunk sizes.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -422,7 +422,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int): ) -def test_shared_param_retaining_grad_fails(): +def test_shared_param_retaining_grad_fails() -> None: """ Tests that mtl_backward fails to fill a valid `.grad` when some shared param in the computation graph of the ``features`` parameter retains grad and vmap has to be used. @@ -451,7 +451,7 @@ def test_shared_param_retaining_grad_fails(): _ = -a.grad # type: ignore[unsupported-operator] -def test_shared_activation_retaining_grad_fails(): +def test_shared_activation_retaining_grad_fails() -> None: """ Tests that mtl_backward fails to fill a valid `.grad` when some tensor in the computation graph of the ``features`` parameter retains grad and vmap has to be used. @@ -480,7 +480,7 @@ def test_shared_activation_retaining_grad_fails(): _ = -a.grad # type: ignore[unsupported-operator] -def test_tasks_params_overlap(): +def test_tasks_params_overlap() -> None: """Tests that mtl_backward works correctly when the tasks' parameters have some overlap.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -502,7 +502,7 @@ def test_tasks_params_overlap(): assert_jac_close(p0, J) -def test_tasks_params_are_the_same(): +def test_tasks_params_are_the_same() -> None: """Tests that mtl_backward works correctly when the tasks have the same params.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -520,7 +520,7 @@ def test_tasks_params_are_the_same(): assert_jac_close(p0, J) -def test_task_params_is_subset_of_other_task_params(): +def test_task_params_is_subset_of_other_task_params() -> None: """ Tests that mtl_backward works correctly when one task's params is a subset of another task's params. @@ -543,7 +543,7 @@ def test_task_params_is_subset_of_other_task_params(): assert_jac_close(p0, J) -def test_shared_params_overlapping_with_tasks_params_fails(): +def test_shared_params_overlapping_with_tasks_params_fails() -> None: """ Tests that mtl_backward raises an error when the set of shared params overlaps with the set of task-specific params. @@ -566,7 +566,7 @@ def test_shared_params_overlapping_with_tasks_params_fails(): ) -def test_default_shared_params_overlapping_with_default_tasks_params_fails(): +def test_default_shared_params_overlapping_with_default_tasks_params_fails() -> None: """ Tests that mtl_backward raises an error when the set of shared params obtained by default overlaps with the set of task-specific params obtained by default. @@ -587,7 +587,7 @@ def test_default_shared_params_overlapping_with_default_tasks_params_fails(): ) -def test_repeated_losses(): +def test_repeated_losses() -> None: """ Tests that mtl_backward does not allow repeating losses. @@ -610,7 +610,7 @@ def test_repeated_losses(): mtl_backward(losses, features=[f1, f2], retain_graph=True) -def test_repeated_features(): +def test_repeated_features() -> None: """ Tests that mtl_backward does not allow repeating features. @@ -633,7 +633,7 @@ def test_repeated_features(): mtl_backward([y1, y2], features=features) -def test_repeated_shared_params(): +def test_repeated_shared_params() -> None: """ Tests that mtl_backward correctly works when some shared are repeated. Since these are tensors with respect to which we differentiate, to match the behavior of torch.autograd.backward, this @@ -661,7 +661,7 @@ def test_repeated_shared_params(): assert_grad_close(p2, g2) -def test_repeated_task_params(): +def test_repeated_task_params() -> None: """ Tests that mtl_backward correctly works when some task-specific params are repeated for some task. Since these are tensors with respect to which we differentiate, to match the behavior of @@ -689,7 +689,7 @@ def test_repeated_task_params(): assert_grad_close(p2, g2) -def test_grad_tensors_value_is_correct(): +def test_grad_tensors_value_is_correct() -> None: """ Tests that mtl_ackward correctly computes the element-wise product of grad_tensors and the tensors. @@ -724,7 +724,7 @@ def test_grad_tensors_value_is_correct(): assert_jac_close(p0, expected_jacobian) -def test_grad_tensors_length_mismatch(): +def test_grad_tensors_length_mismatch() -> None: """Tests that mtl_backward raises a ValueError early if len(grad_tensors) != len(tensors).""" p0 = randn_(3, requires_grad=True) @@ -747,7 +747,7 @@ def test_grad_tensors_length_mismatch(): ) -def test_grad_tensors_shape_mismatch(): +def test_grad_tensors_shape_mismatch() -> None: """ Tests that mtl_backward raises a ValueError early if the shape of a tensor in grad_tensors is incompatible with the corresponding tensor. diff --git a/tests/unit/autojac/test_utils.py b/tests/unit/autojac/test_utils.py index f4dbf7a4..bb1450d8 100644 --- a/tests/unit/autojac/test_utils.py +++ b/tests/unit/autojac/test_utils.py @@ -6,7 +6,7 @@ from torchjd.autojac._utils import get_leaf_tensors -def test_simple_get_leaf_tensors(): +def test_simple_get_leaf_tensors() -> None: """Tests that _get_leaf_tensors works correctly in a very simple setting.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -19,7 +19,7 @@ def test_simple_get_leaf_tensors(): assert set(leaves) == {a1, a2} -def test_get_leaf_tensors_excluded_1(): +def test_get_leaf_tensors_excluded_1() -> None: """ Tests that _get_leaf_tensors works correctly when some tensors are excluded from the search. @@ -40,7 +40,7 @@ def test_get_leaf_tensors_excluded_1(): assert set(leaves) == {a1} -def test_get_leaf_tensors_excluded_2(): +def test_get_leaf_tensors_excluded_2() -> None: """ Tests that _get_leaf_tensors works correctly when some tensors are excluded from the search. @@ -61,7 +61,7 @@ def test_get_leaf_tensors_excluded_2(): assert set(leaves) == {a1, a2} -def test_get_leaf_tensors_leaf_not_requiring_grad(): +def test_get_leaf_tensors_leaf_not_requiring_grad() -> None: """ Tests that _get_leaf_tensors does not include tensors that do not require grad in its results. """ @@ -76,7 +76,7 @@ def test_get_leaf_tensors_leaf_not_requiring_grad(): assert set(leaves) == {a1} -def test_get_leaf_tensors_model(): +def test_get_leaf_tensors_model() -> None: """ Tests that _get_leaf_tensors works correctly when the autograd graph is generated by a simple sequential model. @@ -95,7 +95,7 @@ def test_get_leaf_tensors_model(): assert set(leaves) == set(model.parameters()) -def test_get_leaf_tensors_model_excluded_2(): +def test_get_leaf_tensors_model_excluded_2() -> None: """ Tests that _get_leaf_tensors works correctly when the autograd graph is generated by a simple sequential model, and some intermediate values are excluded. @@ -116,7 +116,7 @@ def test_get_leaf_tensors_model_excluded_2(): assert set(leaves) == set(model2.parameters()) -def test_get_leaf_tensors_single_root(): +def test_get_leaf_tensors_single_root() -> None: """Tests that _get_leaf_tensors returns no leaves when roots is the empty set.""" p = tensor_([1.0, 2.0], requires_grad=True) @@ -126,14 +126,14 @@ def test_get_leaf_tensors_single_root(): assert set(leaves) == {p} -def test_get_leaf_tensors_empty_roots(): +def test_get_leaf_tensors_empty_roots() -> None: """Tests that _get_leaf_tensors returns no leaves when roots is the empty set.""" leaves = get_leaf_tensors(tensors=[], excluded=set()) assert set(leaves) == set() -def test_get_leaf_tensors_excluded_root(): +def test_get_leaf_tensors_excluded_root() -> None: """Tests that _get_leaf_tensors correctly excludes the root.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -147,7 +147,7 @@ def test_get_leaf_tensors_excluded_root(): @mark.parametrize("depth", [100, 1000, 10000]) -def test_get_leaf_tensors_deep(depth: int): +def test_get_leaf_tensors_deep(depth: int) -> None: """Tests that _get_leaf_tensors works when the graph is very deep.""" one = tensor_(1.0, requires_grad=True) @@ -159,7 +159,7 @@ def test_get_leaf_tensors_deep(depth: int): assert set(leaves) == {one} -def test_get_leaf_tensors_leaf(): +def test_get_leaf_tensors_leaf() -> None: """Tests that _get_leaf_tensors raises an error some of the provided tensors are leaves.""" a = tensor_(1.0, requires_grad=True) @@ -167,7 +167,7 @@ def test_get_leaf_tensors_leaf(): _ = get_leaf_tensors(tensors=[a], excluded=set()) -def test_get_leaf_tensors_tensor_not_requiring_grad(): +def test_get_leaf_tensors_tensor_not_requiring_grad() -> None: """ Tests that _get_leaf_tensors raises an error some of the provided tensors do not require grad. """ @@ -177,7 +177,7 @@ def test_get_leaf_tensors_tensor_not_requiring_grad(): _ = get_leaf_tensors(tensors=[a], excluded=set()) -def test_get_leaf_tensors_excluded_leaf(): +def test_get_leaf_tensors_excluded_leaf() -> None: """Tests that _get_leaf_tensors raises an error some of the excluded tensors are leaves.""" a = tensor_(1.0, requires_grad=True) * 2 @@ -186,7 +186,7 @@ def test_get_leaf_tensors_excluded_leaf(): _ = get_leaf_tensors(tensors=[a], excluded={b}) -def test_get_leaf_tensors_excluded_not_requiring_grad(): +def test_get_leaf_tensors_excluded_not_requiring_grad() -> None: """ Tests that _get_leaf_tensors raises an error some of the excluded tensors do not require grad. """ diff --git a/tests/unit/linalg/test_gramian.py b/tests/unit/linalg/test_gramian.py index 53373822..2c8c1eec 100644 --- a/tests/unit/linalg/test_gramian.py +++ b/tests/unit/linalg/test_gramian.py @@ -20,13 +20,13 @@ [6, 7, 9], ], ) -def test_gramian_is_psd(shape: list[int]): +def test_gramian_is_psd(shape: list[int]) -> None: matrix = randn_(shape) gramian = compute_gramian(matrix) assert_is_psd_matrix(gramian) -def test_compute_gramian_scalar_input_0(): +def test_compute_gramian_scalar_input_0() -> None: t = tensor_(5.0) gramian = compute_gramian(t, contracted_dims=0) expected = tensor_(25.0) @@ -34,7 +34,7 @@ def test_compute_gramian_scalar_input_0(): assert_close(gramian, expected) -def test_compute_gramian_vector_input_0(): +def test_compute_gramian_vector_input_0() -> None: t = tensor_([2.0, 3.0]) gramian = compute_gramian(t, contracted_dims=0) expected = tensor_([[4.0, 6.0], [6.0, 9.0]]) @@ -42,7 +42,7 @@ def test_compute_gramian_vector_input_0(): assert_close(gramian, expected) -def test_compute_gramian_vector_input_1(): +def test_compute_gramian_vector_input_1() -> None: t = tensor_([2.0, 3.0]) gramian = compute_gramian(t, contracted_dims=1) expected = tensor_(13.0) @@ -50,7 +50,7 @@ def test_compute_gramian_vector_input_1(): assert_close(gramian, expected) -def test_compute_gramian_matrix_input_0(): +def test_compute_gramian_matrix_input_0() -> None: t = tensor_([[1.0, 2.0], [3.0, 4.0]]) gramian = compute_gramian(t, contracted_dims=0) expected = tensor_( @@ -63,7 +63,7 @@ def test_compute_gramian_matrix_input_0(): assert_close(gramian, expected) -def test_compute_gramian_matrix_input_1(): +def test_compute_gramian_matrix_input_1() -> None: t = tensor_([[1.0, 2.0], [3.0, 4.0]]) gramian = compute_gramian(t, contracted_dims=1) expected = tensor_([[5.0, 11.0], [11.0, 25.0]]) @@ -71,7 +71,7 @@ def test_compute_gramian_matrix_input_1(): assert_close(gramian, expected) -def test_compute_gramian_matrix_input_2(): +def test_compute_gramian_matrix_input_2() -> None: t = tensor_([[1.0, 2.0], [3.0, 4.0]]) gramian = compute_gramian(t, contracted_dims=2) expected = tensor_(30.0) @@ -89,7 +89,7 @@ def test_compute_gramian_matrix_input_2(): [5, 0], ], ) -def test_normalize_yields_psd(shape: list[int]): +def test_normalize_yields_psd(shape: list[int]) -> None: matrix = randn_(shape) assert is_matrix(matrix) gramian = compute_gramian(matrix) @@ -107,7 +107,7 @@ def test_normalize_yields_psd(shape: list[int]): [5, 0], ], ) -def test_regularize_yields_psd(shape: list[int]): +def test_regularize_yields_psd(shape: list[int]) -> None: matrix = randn_(shape) assert is_matrix(matrix) gramian = compute_gramian(matrix) diff --git a/tests/unit/test_deprecations.py b/tests/unit/test_deprecations.py index f1121478..7d38a04d 100644 --- a/tests/unit/test_deprecations.py +++ b/tests/unit/test_deprecations.py @@ -2,7 +2,7 @@ # deprecated since 2025-08-18 -def test_deprecate_imports_from_torchjd(): +def test_deprecate_imports_from_torchjd() -> None: with pytest.deprecated_call(): from torchjd import backward # noqa: F401 diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 45f1417d..df2c24a5 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -201,7 +201,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False # don't suppress exceptions - def _restore_original_params(self, module: nn.Module): + def _restore_original_params(self, module: nn.Module) -> None: original_params = self._module_to_original_params.pop(module, {}) for name, param in original_params.items(): self._set_module_param(module, name, param) From f0330c1694e3052918fdba08412dcfff3fbec56a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:39:04 +0100 Subject: [PATCH 06/22] Add missing -> str --- tests/unit/autojac/_transform/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autojac/_transform/test_base.py b/tests/unit/autojac/_transform/test_base.py index a04dd866..d0d7f96d 100644 --- a/tests/unit/autojac/_transform/test_base.py +++ b/tests/unit/autojac/_transform/test_base.py @@ -14,7 +14,7 @@ def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]) -> None self._required_keys = required_keys self._output_keys = output_keys - def __str__(self): + def __str__(self) -> str: return "T" def __call__(self, _input: TensorDict, /) -> TensorDict: From ecc011b55b32c35271e5cfe79d5bf0dc11391332 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:40:38 +0100 Subject: [PATCH 07/22] Annotate obj as object in docs/source/conf.py --- docs/source/conf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index f0daa9b2..ecad14af 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -100,7 +100,7 @@ def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None: return link -def _get_obj(_info: dict[str, str]): +def _get_obj(_info: dict[str, str]) -> object: module_name = _info["module"] full_name = _info["fullname"] sub_module = sys.modules.get(module_name) @@ -112,7 +112,7 @@ def _get_obj(_info: dict[str, str]): return obj -def _get_file_name(obj) -> str | None: +def _get_file_name(obj: object) -> str | None: try: file_name = inspect.getsourcefile(obj) file_name = os.path.relpath(file_name, start=_PATH_ROOT) @@ -121,7 +121,7 @@ def _get_file_name(obj) -> str | None: return file_name -def _get_line_str(obj) -> str: +def _get_line_str(obj: object) -> str: source, start = inspect.getsourcelines(obj) end = start + len(source) - 1 line_str = f"#L{start}-L{end}" From 3c022e7557d8c667df03f751d6c87af07cae91d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:44:52 +0100 Subject: [PATCH 08/22] Add missing type annotations in conftest.py --- tests/conftest.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 478c48cf..e2b75d67 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import warnings from contextlib import nullcontext +import pytest import torch from pytest import RaisesExc, fixture, mark from settings import DEVICE @@ -30,16 +31,16 @@ def fix_randomness() -> None: torch.use_deterministic_algorithms(True) -def pytest_addoption(parser) -> None: +def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") -def pytest_configure(config) -> None: +def pytest_configure(config: pytest.Config) -> None: config.addinivalue_line("markers", "slow: mark test as slow to run") config.addinivalue_line("markers", "xfail_if_cuda: mark test as xfail if running on cuda") -def pytest_collection_modifyitems(config, items) -> None: +def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: skip_slow = mark.skip(reason="Slow test. Use --runslow to run it.") xfail_cuda = mark.xfail(reason=f"Test expected to fail on {DEVICE}") for item in items: @@ -49,7 +50,7 @@ def pytest_collection_modifyitems(config, items) -> None: item.add_marker(xfail_cuda) -def pytest_make_parametrize_id(config, val, argname): +def pytest_make_parametrize_id(config: pytest.Config, val: object, argname: str) -> str | None: MAX_SIZE = 40 optional_string = None # Returning None means using pytest's way of making the string From 0d7ea2e086fd51faa3b91e0ef15751936816180c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:47:05 +0100 Subject: [PATCH 09/22] Add type annotations in _make_tensors --- tests/utils/tensors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/tensors.py b/tests/utils/tensors.py index 7988157d..8febba42 100644 --- a/tests/utils/tensors.py +++ b/tests/utils/tensors.py @@ -37,7 +37,7 @@ def make_inputs_and_targets(model: nn.Module, batch_size: int) -> tuple[PyTree, def _make_tensors(batch_size: int, tensor_shapes: PyTree) -> PyTree: - def is_leaf(s): + def is_leaf(s: PyTree) -> bool: return isinstance(s, tuple) and all(isinstance(e, int) for e in s) return tree_map(lambda s: randn_((batch_size, *s)), tensor_shapes, is_leaf=is_leaf) From 0d88143a964e8440e01718071173432fbf9f3cba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:48:44 +0100 Subject: [PATCH 10/22] Add type annotations to CloneParams --- tests/utils/forward_backwards.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index df2c24a5..008fbad4 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from types import TracebackType import torch from torch import Tensor, nn, vmap @@ -192,7 +193,12 @@ def post_hook(module: nn.Module, _, __) -> None: return self.clones - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: """Remove hooks and restore parameters.""" for handle in self._handles: handle.remove() From eb182c48dc9d301a0ae8493027352292adb9b44f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:49:40 +0100 Subject: [PATCH 11/22] Add type annotation to InterModuleParamReuse.forward --- tests/utils/architectures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index d83917cd..15d4ece3 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -456,7 +456,7 @@ def __init__(self, matrix: nn.Parameter) -> None: super().__init__() self.matrix = matrix - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return input @ self.matrix def __init__(self) -> None: From e11de010a2355cc9d8b5236b389fc05b0773146a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:51:17 +0100 Subject: [PATCH 12/22] Add -> None to __init_subclass__ --- tests/utils/architectures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 15d4ece3..d324ebff 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -36,7 +36,7 @@ class ShapedModule(nn.Module): INPUT_SHAPES: PyTree # meant to be overridden OUTPUT_SHAPES: PyTree # meant to be overridden - def __init_subclass__(cls): + def __init_subclass__(cls) -> None: super().__init_subclass__() if getattr(cls, "INPUT_SHAPES", None) is None: raise TypeError(f"{cls.__name__} must define INPUT_SHAPES") From aacf6718d4b7d044ba4d619ed49137799c73a0e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:52:17 +0100 Subject: [PATCH 13/22] Add missing type annotation for chunk_size --- tests/unit/autojac/test_backward.py | 2 +- tests/unit/autojac/test_jac.py | 2 +- tests/unit/autojac/test_mtl_backward.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index edd2f197..806eb545 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -268,7 +268,7 @@ def test_multiple_tensors() -> None: @mark.parametrize("chunk_size", [None, 1, 2, 4]) -def test_various_valid_chunk_sizes(chunk_size) -> None: +def test_various_valid_chunk_sizes(chunk_size: int | None) -> None: """Tests that backward works for various valid values of parallel_chunk_size.""" a1 = tensor_([1.0, 2.0], requires_grad=True) diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index c880eacc..75c68cb9 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -268,7 +268,7 @@ def test_multiple_tensors() -> None: @mark.parametrize("chunk_size", [None, 1, 2, 4]) -def test_various_valid_chunk_sizes(chunk_size) -> None: +def test_various_valid_chunk_sizes(chunk_size: int | None) -> None: """Tests that jac works for various valid values of parallel_chunk_size.""" a1 = tensor_([1.0, 2.0], requires_grad=True) diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 6de8d7aa..5ca13882 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -378,7 +378,7 @@ def test_non_scalar_loss_fails() -> None: @mark.parametrize("chunk_size", [None, 1, 2, 4]) -def test_various_valid_chunk_sizes(chunk_size) -> None: +def test_various_valid_chunk_sizes(chunk_size: int | None) -> None: """Tests that mtl_backward works for various valid values of parallel_chunk_size.""" p0 = tensor_([1.0, 2.0], requires_grad=True) From 4c2b730101c378ea4fb9ebd31c0210210aca98d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:55:33 +0100 Subject: [PATCH 14/22] Add type annotations to time_call --- tests/profiling/speed_grad_vs_jac_vs_gram.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/profiling/speed_grad_vs_jac_vs_gram.py b/tests/profiling/speed_grad_vs_jac_vs_gram.py index b4f9825d..b670a96d 100644 --- a/tests/profiling/speed_grad_vs_jac_vs_gram.py +++ b/tests/profiling/speed_grad_vs_jac_vs_gram.py @@ -1,5 +1,6 @@ import gc import time +from collections.abc import Callable import torch from settings import DEVICE @@ -125,7 +126,13 @@ def noop() -> None: pass -def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> Tensor: +def time_call( + fn: Callable[[], None], + init_fn: Callable[[], None] = noop, + pre_fn: Callable[[], None] = noop, + post_fn: Callable[[], None] = noop, + n_runs: int = 10, +) -> Tensor: init_fn() times = [] From 2f53b64c1f40490ecc9894186150903b2dc897bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:58:21 +0100 Subject: [PATCH 15/22] Add type annotations in run_profiler.py --- tests/profiling/run_profiler.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/profiling/run_profiler.py b/tests/profiling/run_profiler.py index ba793610..9807849b 100644 --- a/tests/profiling/run_profiler.py +++ b/tests/profiling/run_profiler.py @@ -3,7 +3,9 @@ import torch from settings import DEVICE +from torch import Tensor, nn from torch.profiler import ProfilerActivity, profile +from torch.utils._pytree import PyTree from utils.architectures import ( AlexNet, Cifar10Model, @@ -105,7 +107,9 @@ def _save_and_print_trace( def profile_autojac(factory: ModuleFactory, batch_size: int) -> None: - def forward_backward_fn(model, inputs, loss_fn) -> None: + def forward_backward_fn( + model: nn.Module, inputs: PyTree, loss_fn: Callable[[PyTree], list[Tensor]] + ) -> None: aggregator = UPGrad() autojac_forward_backward(model, inputs, loss_fn, aggregator) @@ -113,7 +117,9 @@ def forward_backward_fn(model, inputs, loss_fn) -> None: def profile_autogram(factory: ModuleFactory, batch_size: int) -> None: - def forward_backward_fn(model, inputs, loss_fn) -> None: + def forward_backward_fn( + model: nn.Module, inputs: PyTree, loss_fn: Callable[[PyTree], list[Tensor]] + ) -> None: engine = Engine(model, batch_dim=0) weighting = UPGradWeighting() autogram_forward_backward(model, inputs, loss_fn, engine, weighting) From 02f4980ebe64f906dc6f6d6c88b7f087a1327c11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 11:59:41 +0100 Subject: [PATCH 16/22] Add -> None to main in plot_memory_timeline.py --- tests/profiling/plot_memory_timeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/profiling/plot_memory_timeline.py b/tests/profiling/plot_memory_timeline.py index 0f792be1..3b15256f 100644 --- a/tests/profiling/plot_memory_timeline.py +++ b/tests/profiling/plot_memory_timeline.py @@ -112,7 +112,7 @@ def plot_memory_timelines(experiment: str, folders: list[str]) -> None: print("Plot saved successfully!") -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="Plot memory timeline from profiling traces.") parser.add_argument( "experiment", From 040570d934729b9e10b25f0cf8c2c6cec0827a5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 12:00:31 +0100 Subject: [PATCH 17/22] Add return type annotation in MemoryFrame.from_event --- tests/profiling/plot_memory_timeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/profiling/plot_memory_timeline.py b/tests/profiling/plot_memory_timeline.py index 3b15256f..4c2d9439 100644 --- a/tests/profiling/plot_memory_timeline.py +++ b/tests/profiling/plot_memory_timeline.py @@ -24,7 +24,7 @@ class MemoryFrame: device_id: int # -1 for CPU, 0+ for CUDA devices @staticmethod - def from_event(event: dict): + def from_event(event: dict) -> "MemoryFrame": args = event["args"] return MemoryFrame( timestamp=event["ts"], From 7837c10153f041c7cf6a179d7d0173c05bb35046 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 12:01:36 +0100 Subject: [PATCH 18/22] Add type annotations in static_plotter and rename a variable --- tests/plots/static_plotter.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/plots/static_plotter.py b/tests/plots/static_plotter.py index 928d8eea..de5d6836 100644 --- a/tests/plots/static_plotter.py +++ b/tests/plots/static_plotter.py @@ -25,13 +25,13 @@ def main( *, - gradients=False, - cone=False, - projections=False, - upgrad=False, - mean=False, - dual_proj=False, - mgda=False, + gradients: bool = False, + cone: bool = False, + projections: bool = False, + upgrad: bool = False, + mean: bool = False, + dual_proj: bool = False, + mgda: bool = False, ) -> None: angle1 = 2.6 angle2 = 0.3277 @@ -76,13 +76,13 @@ def main( if cone: filename += "_cone" start_angle, opening = compute_2d_non_conflicting_cone(matrix.numpy()) - cone = make_cone_scatter( + cone_scatter = make_cone_scatter( start_angle, opening, label="Non-conflicting cone", printable=False, ) - fig.add_trace(cone) + fig.add_trace(cone_scatter) if projections: filename += "_projections" From 0fe3b9bdc0c1bf2c3462c4269911afaf0d2dcd98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 12:03:29 +0100 Subject: [PATCH 19/22] Add noqa comment in test_lightning_integration --- tests/doc/test_rst.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index fcd6f392..90dc099a 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -238,7 +238,7 @@ def __init__(self) -> None: self.task2_head = Linear(3, 1) self.automatic_optimization = False - def training_step(self, batch, batch_idx) -> None: + def training_step(self, batch, batch_idx) -> None: # noqa: ANN001 input, target1, target2 = batch features = self.feature_extractor(input) From 693db5f2bf5e80bb0d3f969045f58971dbce0ca5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 12:05:11 +0100 Subject: [PATCH 20/22] Add type annotation to update_gradient_coordinate --- tests/plots/interactive_plotter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index 0d26a1e6..2a945f93 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -120,7 +120,7 @@ def update_seed(value: int) -> Figure: *gradient_slider_inputs, prevent_initial_call=True, ) - def update_gradient_coordinate(*values) -> Figure: + def update_gradient_coordinate(*values: str) -> Figure: values_ = [float(value) for value in values] for j in range(len(values_) // 2): From 802b7776231dea441ffdfa82c3e0d5f5f4175ebe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 12:08:46 +0100 Subject: [PATCH 21/22] Add annotation as Any for args and kwargs --- tests/utils/architectures.py | 4 ++-- tests/utils/asserts.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index d324ebff..20ce71cd 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar import torch import torchvision @@ -14,7 +14,7 @@ class ModuleFactory(Generic[_T]): - def __init__(self, architecture: type[_T], *args, **kwargs) -> None: + def __init__(self, architecture: type[_T], *args: Any, **kwargs: Any) -> None: self.architecture: type[_T] = architecture self.args = args self.kwargs = kwargs diff --git a/tests/utils/asserts.py b/tests/utils/asserts.py index 836110a8..392422cf 100644 --- a/tests/utils/asserts.py +++ b/tests/utils/asserts.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from torch import Tensor from torch.testing import assert_close @@ -16,7 +18,7 @@ def assert_has_no_jac(t: Tensor) -> None: assert not is_tensor_with_jac(t) -def assert_jac_close(t: Tensor, expected_jac: Tensor, **kwargs) -> None: +def assert_jac_close(t: Tensor, expected_jac: Tensor, **kwargs: Any) -> None: assert is_tensor_with_jac(t) assert_close(t.jac, expected_jac, **kwargs) @@ -29,12 +31,12 @@ def assert_has_no_grad(t: Tensor) -> None: assert t.grad is None -def assert_grad_close(t: Tensor, expected_grad: Tensor, **kwargs) -> None: +def assert_grad_close(t: Tensor, expected_grad: Tensor, **kwargs: Any) -> None: assert t.grad is not None assert_close(t.grad, expected_grad, **kwargs) -def assert_is_psd_matrix(matrix: Tensor, **kwargs) -> None: +def assert_is_psd_matrix(matrix: Tensor, **kwargs: Any) -> None: assert is_psd_matrix(matrix) assert_close(matrix, matrix.mH, **kwargs) @@ -44,7 +46,7 @@ def assert_is_psd_matrix(matrix: Tensor, **kwargs) -> None: assert_close(eig_vals, expected_eig_vals, **kwargs) -def assert_is_psd_tensor(t: Tensor, **kwargs) -> None: +def assert_is_psd_tensor(t: Tensor, **kwargs: Any) -> None: assert is_psd_tensor(t) matrix = flatten(t) assert_is_psd_matrix(matrix, **kwargs) From bab899531e94aedbd677d021ae2d22271da93d88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 19 Feb 2026 12:17:55 +0100 Subject: [PATCH 22/22] Add -> to test_noncontiguous_jac --- tests/unit/autojac/test_jac_to_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index 2e9dca4b..b8ea5c6c 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -103,7 +103,7 @@ def test_jacs_are_freed(retain_jac: bool) -> None: check(t2) -def test_noncontiguous_jac(): +def test_noncontiguous_jac() -> None: """Tests that jac_to_grad works when the .jac field is non-contiguous.""" aggregator = UPGrad()