diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py
index d4be05e9..4abd6466 100644
--- a/src/torchjd/aggregation/_aggregator_bases.py
+++ b/src/torchjd/aggregation/_aggregator_bases.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Generic, TypeVar
 
 from torch import Tensor, nn
 
@@ -68,7 +69,10 @@ def forward(self, matrix: Matrix, /) -> Tensor:
         return vector
 
 
-class GramianWeightedAggregator(WeightedAggregator):
+_T = TypeVar("_T", covariant=True, bound=Weighting[PSDMatrix])
+
+
+class GramianWeightedAggregator(WeightedAggregator, Generic[_T]):
     """
     WeightedAggregator that computes the gramian of the input jacobian matrix before applying a
     Weighting to it.
@@ -77,6 +81,6 @@ class GramianWeightedAggregator(WeightedAggregator):
     gramian.
     """
 
-    def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None:
+    def __init__(self, gramian_weighting: _T) -> None:
         super().__init__(gramian_weighting << compute_gramian)
         self.gramian_weighting = gramian_weighting
diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py
index 6731178b..6e0ab8c9 100644
--- a/src/torchjd/aggregation/_cagrad.py
+++ b/src/torchjd/aggregation/_cagrad.py
@@ -18,38 +18,6 @@
 from ._utils.non_differentiable import raise_non_differentiable_error
 
 
-class CAGrad(GramianWeightedAggregator):
-    """
-    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
-    `Conflict-Averse Gradient Descent for Multi-task Learning
-    <https://arxiv.org/pdf/2110.14048.pdf>`_.
-
-    :param c: The scale of the radius of the ball constraint.
-    :param norm_eps: A small value to avoid division by zero when normalizing.
-
-    .. note::
-        This aggregator is not installed by default. When not installed, trying to import it should
-        result in the following error:
-        ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
-        To install it, use ``pip install "torchjd[cagrad]"``.
-    """
-
-    def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
-        super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
-        self._c = c
-        self._norm_eps = norm_eps
-
-        # This prevents considering the computed weights as constant w.r.t. the matrix.
-        self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"
-
-    def __str__(self) -> str:
-        c_str = str(self._c).rstrip("0")
-        return f"CAGrad{c_str}"
-
-
 class CAGradWeighting(Weighting[PSDMatrix]):
     """
     :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
@@ -69,13 +37,22 @@ class CAGradWeighting(Weighting[PSDMatrix]):
 
     def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
         super().__init__()
-
-        if c < 0.0:
-            raise ValueError(f"Parameter `c` should be a non-negative float. Found `c = {c}`.")
-
         self.c = c
         self.norm_eps = norm_eps
 
+    @property
+    def c(self) -> float:
+        return self._c
+
+    @c.setter
+    def c(self, value: float) -> None:
+        if value < 0.0:
+            raise ValueError(
+                f"Parameter `c` should be a non-negative float. Found `c = {value}`."
+            )
+
+        self._c = value
+
     def forward(self, gramian: PSDMatrix, /) -> Tensor:
         U, S, _ = torch.svd(normalize(gramian, self.norm_eps))
 
@@ -104,3 +81,44 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
         weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)
 
         return weights
+
+
+class CAGrad(GramianWeightedAggregator[CAGradWeighting]):
+    """
+    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
+    `Conflict-Averse Gradient Descent for Multi-task Learning
+    <https://arxiv.org/pdf/2110.14048.pdf>`_.
+
+    :param c: The scale of the radius of the ball constraint.
+    :param norm_eps: A small value to avoid division by zero when normalizing.
+
+    .. note::
+        This aggregator is not installed by default. When not installed, trying to import it should
+        result in the following error:
+        ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
+        To install it, use ``pip install "torchjd[cagrad]"``.
+    """
+
+    def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
+        super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
+        self.c = c
+        self._norm_eps = norm_eps
+
+        # This prevents considering the computed weights as constant w.r.t. the matrix.
+        self.register_full_backward_pre_hook(raise_non_differentiable_error)
+
+    @property
+    def c(self) -> float:
+        return self._c
+
+    @c.setter
+    def c(self, value: float) -> None:
+        self.gramian_weighting.c = value
+        self._c = value
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"
+
+    def __str__(self) -> str:
+        c_str = str(self._c).rstrip("0")
+        return f"CAGrad{c_str}"