diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3847963f..9e55d66a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -210,7 +210,8 @@ implementation of a mathematical aggregator. > [!NOTE] > We also accept stateful aggregators, whose output depends both on the Jacobian and on some -> internal state (which can be affected for example by previous Jacobians). +> internal state (which can be affected for example by previous Jacobians). Such aggregators should +> inherit from the `Stateful` mixin and implement a `reset` method. > [!NOTE] > Some aggregators may depend on something else than the Jacobian. To implement them, please add diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 64ba6f63..0dbdab0c 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -22,6 +22,10 @@ Abstract base classes :undoc-members: :exclude-members: forward +.. autoclass:: torchjd.aggregation.Stateful + :members: + :undoc-members: + .. 
toctree:: :hidden: diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 93f824e3..400cfe27 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -71,6 +71,7 @@ from ._krum import Krum, KrumWeighting from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting +from ._mixins import Stateful from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting from ._sum import Sum, SumWeighting @@ -107,6 +108,7 @@ "PCGradWeighting", "Random", "RandomWeighting", + "Stateful", "Sum", "SumWeighting", "TrimmedMean", diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 57a08964..efb55f44 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -6,14 +6,16 @@ from torch import Tensor from torchjd._linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator +from ._mixins import Stateful from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import Weighting -class GradVac(GramianWeightedAggregator): +class GradVac(GramianWeightedAggregator, Stateful): r""" + :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) @@ -71,8 +73,9 @@ def __repr__(self) -> str: return f"GradVac(beta={self.beta!r}, eps={self.eps!r})" -class GradVacWeighting(Weighting[PSDMatrix]): +class GradVacWeighting(Weighting[PSDMatrix], Stateful): r""" + :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.GradVac`. 
diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py new file mode 100644 index 00000000..8481feab --- /dev/null +++ b/src/torchjd/aggregation/_mixins.py @@ -0,0 +1,9 @@ +from abc import ABC, abstractmethod + + +class Stateful(ABC): +    """Mixin for classes holding an internal state, adding an abstract ``reset`` method.""" + +    @abstractmethod +    def reset(self) -> None: +        """Resets the internal state.""" diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 1edd5a0d..f64f5182 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -1,15 +1,16 @@ # Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon. # See NOTICES for the full license text. -from typing import cast - from torchjd._linalg import Matrix +from ._mixins import Stateful from ._utils.check_dependencies import check_dependencies_are_installed from ._weighting_bases import Weighting check_dependencies_are_installed(["cvxpy", "ecos"]) +from typing import cast + import cvxpy as cp import numpy as np import torch @@ -20,8 +21,9 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class NashMTL(WeightedAggregator): +class NashMTL(WeightedAggregator, Stateful): """ + :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of `Multi-Task Learning as a Bargaining Game `_. @@ -83,11 +85,11 @@ def __repr__(self) -> str: ) -class _NashMTLWeighting(Weighting[Matrix]): + ... :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` that + extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining + Game `_. 
:param n_tasks: The number of tasks, corresponding to the number of rows in the provided matrices.