From 481bf67e5ba0b3daee91afce230a326f2d9ef78c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 15 Apr 2026 11:52:28 +0200 Subject: [PATCH 1/4] Make GramianWeightedAggregator generic --- src/torchjd/aggregation/_aggregator_bases.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index d4be05e99..1eda917ff 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Generic, TypeVar from torch import Tensor, nn @@ -68,7 +69,10 @@ def forward(self, matrix: Matrix, /) -> Tensor: return vector -class GramianWeightedAggregator(WeightedAggregator): +_T = TypeVar("_T", bound=Weighting[PSDMatrix]) + + +class GramianWeightedAggregator(WeightedAggregator, Generic[_T]): """ WeightedAggregator that computes the gramian of the input jacobian matrix before applying a Weighting to it. @@ -77,6 +81,6 @@ class GramianWeightedAggregator(WeightedAggregator): gramian. """ - def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None: + def __init__(self, gramian_weighting: _T) -> None: super().__init__(gramian_weighting << compute_gramian) self.gramian_weighting = gramian_weighting From 4fb16da935ec050a7cd4347251c9b06e9caca18c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 15 Apr 2026 11:52:39 +0200 Subject: [PATCH 2/4] Add CAGrad setters --- src/torchjd/aggregation/_cagrad.py | 95 +++++++++++++++++++----------- 1 file changed, 59 insertions(+), 36 deletions(-) diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 6731178bb..9daf70bd0 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -18,38 +18,6 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class CAGrad(GramianWeightedAggregator): - """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of - `Conflict-Averse Gradient Descent for Multi-task Learning - `_. - - :param c: The scale of the radius of the ball constraint. - :param norm_eps: A small value to avoid division by zero when normalizing. - - .. note:: - This aggregator is not installed by default. When not installed, trying to import it should - result in the following error: - ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``. - To install it, use ``pip install "torchjd[cagrad]"``. - """ - - def __init__(self, c: float, norm_eps: float = 0.0001) -> None: - super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps)) - self._c = c - self._norm_eps = norm_eps - - # This prevents considering the computed weights as constant w.r.t. the matrix. - self.register_full_backward_pre_hook(raise_non_differentiable_error) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})" - - def __str__(self) -> str: - c_str = str(self._c).rstrip("0") - return f"CAGrad{c_str}" - - class CAGradWeighting(Weighting[PSDMatrix]): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of @@ -69,13 +37,22 @@ class CAGradWeighting(Weighting[PSDMatrix]): def __init__(self, c: float, norm_eps: float = 0.0001) -> None: super().__init__() - - if c < 0.0: - raise ValueError(f"Parameter `c` should be a non-negative float. Found `c = {c}`.") - self.c = c self.norm_eps = norm_eps + @property + def c(self) -> float: + return self._c + + @c.setter + def c(self, value: float) -> None: + if value < 0.0: + raise ValueError( + f"Parameter `value` should be a non-negative float. Found `value = {value}`." + ) + + self._c = value + def forward(self, gramian: PSDMatrix, /) -> Tensor: U, S, _ = torch.svd(normalize(gramian, self.norm_eps)) @@ -104,3 +81,49 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype) return weights + + +class CAGrad(GramianWeightedAggregator[CAGradWeighting]): + """ + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of + `Conflict-Averse Gradient Descent for Multi-task Learning + `_. + + :param c: The scale of the radius of the ball constraint. + :param norm_eps: A small value to avoid division by zero when normalizing. + + .. note:: + This aggregator is not installed by default. When not installed, trying to import it should + result in the following error: + ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``. + To install it, use ``pip install "torchjd[cagrad]"``. + """ + + def __init__(self, c: float, norm_eps: float = 0.0001) -> None: + super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps)) + self._c = c + self._norm_eps = norm_eps + + # This prevents considering the computed weights as constant w.r.t. the matrix. + self.register_full_backward_pre_hook(raise_non_differentiable_error) + + @property + def c(self) -> float: + return self._c + + @c.setter + def c(self, value: float) -> None: + if value < 0.0: + raise ValueError( + f"Parameter `value` should be a non-negative float. Found `value = {value}`." + ) + + self._c = value + self.gramian_weighting.c = value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})" + + def __str__(self) -> str: + c_str = str(self._c).rstrip("0") + return f"CAGrad{c_str}" From de6ef5233fab0fb559274a8e379570d4e4dad6ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 15 Apr 2026 12:05:24 +0200 Subject: [PATCH 3/4] Fix setter, call it properly --- src/torchjd/aggregation/_cagrad.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 9daf70bd0..6e0ab8c9d 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -101,7 +101,7 @@ class CAGrad(GramianWeightedAggregator[CAGradWeighting]): def __init__(self, c: float, norm_eps: float = 0.0001) -> None: super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps)) - self._c = c + self.c = c self._norm_eps = norm_eps # This prevents considering the computed weights as constant w.r.t. the matrix. @@ -113,13 +113,8 @@ def c(self) -> float: @c.setter def c(self, value: float) -> None: - if value < 0.0: - raise ValueError( - f"Parameter `value` should be a non-negative float. Found `value = {value}`." - ) - - self._c = value self.gramian_weighting.c = value + self._c = value def __repr__(self) -> str: return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})" From 9abaedd1f938516817ccea62449c1faa02cef49f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 15 Apr 2026 13:42:18 +0200 Subject: [PATCH 4/4] Make TypeVar covariant --- src/torchjd/aggregation/_aggregator_bases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 1eda917ff..4abd64664 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -69,7 +69,7 @@ def forward(self, matrix: Matrix, /) -> Tensor: return vector -_T = TypeVar("_T", bound=Weighting[PSDMatrix]) +_T = TypeVar("_T", covariant=True, bound=Weighting[PSDMatrix]) class GramianWeightedAggregator(WeightedAggregator, Generic[_T]):