Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/torchjd/aggregation/_aggregator_bases.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from torch import Tensor, nn

Expand Down Expand Up @@ -68,7 +69,10 @@ def forward(self, matrix: Matrix, /) -> Tensor:
return vector


# Type of the wrapped weighting; bound so only gramian-compatible weightings are accepted.
_T = TypeVar("_T", covariant=True, bound=Weighting[PSDMatrix])


class GramianWeightedAggregator(WeightedAggregator, Generic[_T]):
    """
    WeightedAggregator that computes the gramian of the input jacobian matrix before applying a
    Weighting to it.

    :param gramian_weighting: The Weighting applied to the (PSD) gramian of the jacobian rather
        than to the jacobian itself. Its concrete type parameterizes this class so that subclasses
        can expose a precisely-typed ``gramian_weighting`` attribute.
    """

    def __init__(self, gramian_weighting: _T) -> None:
        # Compose the weighting with the gramian computation: the wrapped weighting
        # receives gramian(J) instead of J itself.
        super().__init__(gramian_weighting << compute_gramian)
        # Kept as an attribute so subclasses (e.g. CAGrad) can forward parameter
        # updates to the underlying weighting.
        self.gramian_weighting = gramian_weighting
90 changes: 54 additions & 36 deletions src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,6 @@
from ._utils.non_differentiable import raise_non_differentiable_error


class CAGrad(GramianWeightedAggregator):
    """
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
    `Conflict-Averse Gradient Descent for Multi-task Learning
    <https://arxiv.org/pdf/2110.14048.pdf>`_.

    :param c: The scale of the radius of the ball constraint.
    :param norm_eps: A small value to avoid division by zero when normalizing.

    .. note::
        This aggregator is not installed by default. When not installed, trying to import it should
        result in the following error:
        ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
        To install it, use ``pip install "torchjd[cagrad]"``.
    """

    def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
        super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
        self._c = c
        self._norm_eps = norm_eps

        # This prevents considering the computed weights as constant w.r.t. the matrix.
        self.register_full_backward_pre_hook(raise_non_differentiable_error)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"

    def __str__(self) -> str:
        # Strip insignificant trailing zeros (and a dangling '.') only when c has a
        # fractional part: a bare rstrip("0") would corrupt integer values (e.g. the
        # original turned c=10 into "CAGrad1" and c=1.0 into "CAGrad1.").
        c_str = str(self._c)
        if "." in c_str:
            c_str = c_str.rstrip("0").rstrip(".")
        return f"CAGrad{c_str}"


class CAGradWeighting(Weighting[PSDMatrix]):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
Expand All @@ -69,13 +37,22 @@

def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
    """
    :param c: The scale of the radius of the ball constraint. Must be non-negative.
    :param norm_eps: A small value to avoid division by zero when normalizing.
    """
    super().__init__()
    # Assigning through the `c` property validates non-negativity (raising ValueError
    # for negative values), so no separate inline check is needed here.
    self.c = c
    self.norm_eps = norm_eps

@property
def c(self) -> float:
    """The scale of the radius of the ball constraint (non-negative)."""
    return self._c

@c.setter
def c(self, value: float) -> None:
    # Validate on every assignment (including the one in __init__) so `_c` can
    # never hold a negative value.
    if value < 0.0:
        # Refer to the public attribute `c`, not the internal setter argument
        # `value`, so the message is meaningful to callers.
        raise ValueError(f"Parameter `c` should be a non-negative float. Found `c = {value}`.")

    self._c = value

def forward(self, gramian: PSDMatrix, /) -> Tensor:
U, S, _ = torch.svd(normalize(gramian, self.norm_eps))

Expand Down Expand Up @@ -104,3 +81,44 @@
weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)

return weights


class CAGrad(GramianWeightedAggregator[CAGradWeighting]):
    """
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
    `Conflict-Averse Gradient Descent for Multi-task Learning
    <https://arxiv.org/pdf/2110.14048.pdf>`_.

    :param c: The scale of the radius of the ball constraint.
    :param norm_eps: A small value to avoid division by zero when normalizing.

    .. note::
        This aggregator is not installed by default. When not installed, trying to import it should
        result in the following error:
        ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
        To install it, use ``pip install "torchjd[cagrad]"``.
    """

    def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
        super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
        # Assigning through the property below also pushes the value into the
        # wrapped gramian_weighting, keeping both in sync.
        self.c = c
        self._norm_eps = norm_eps

        # This prevents considering the computed weights as constant w.r.t. the matrix.
        self.register_full_backward_pre_hook(raise_non_differentiable_error)

    @property
    def c(self) -> float:
        """The scale of the radius of the ball constraint."""
        return self._c

    @c.setter
    def c(self, value: float) -> None:
        # Forward to the wrapped weighting first: its own setter validates
        # non-negativity, so an invalid value raises before `_c` is updated and
        # leaves this aggregator unchanged.
        self.gramian_weighting.c = value
        self._c = value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"

    def __str__(self) -> str:
        # Strip insignificant trailing zeros (and a dangling '.') only when c has a
        # fractional part: a bare rstrip("0") would corrupt integer values (e.g. the
        # original turned c=10 into "CAGrad1" and c=1.0 into "CAGrad1.").
        c_str = str(self._c)
        if "." in c_str:
            c_str = c_str.rstrip("0").rstrip(".")
        return f"CAGrad{c_str}"
Loading