From fac2c1f26a14ae3fdf365c74971a649466d66b79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 15 Apr 2026 12:21:07 +0200 Subject: [PATCH 1/8] Add ResettableMixin * Make NashMTL, NashMTLWeighting, GradVac and GradVacWeighting inherit from it --- src/torchjd/aggregation/_gradvac.py | 5 +++-- src/torchjd/aggregation/_mixins.py | 9 +++++++++ src/torchjd/aggregation/_nash_mtl.py | 9 +++++---- 3 files changed, 17 insertions(+), 6 deletions(-) create mode 100644 src/torchjd/aggregation/_mixins.py diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 57a08964..4b877d4c 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -6,13 +6,14 @@ from torch import Tensor from torchjd._linalg import PSDMatrix +from torchjd.aggregation._mixins import ResettableMixin from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import Weighting -class GradVac(GramianWeightedAggregator): +class GradVac(GramianWeightedAggregator, ResettableMixin): r""" :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task @@ -71,7 +72,7 @@ def __repr__(self) -> str: return f"GradVac(beta={self.beta!r}, eps={self.eps!r})" -class GradVacWeighting(Weighting[PSDMatrix]): +class GradVacWeighting(Weighting[PSDMatrix], ResettableMixin): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.GradVac`. 
diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py new file mode 100644 index 00000000..42e09c53 --- /dev/null +++ b/src/torchjd/aggregation/_mixins.py @@ -0,0 +1,9 @@ +from abc import ABC, abstractmethod + + +class ResettableMixin(ABC): + """Mixin that resettable classes should inherit from.""" + + @abstractmethod + def reset(self) -> None: + """Resets the internal state.""" diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 1edd5a0d..737d7b26 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -1,15 +1,16 @@ # Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon. # See NOTICES for the full license text. -from typing import cast - from torchjd._linalg import Matrix +from torchjd.aggregation._mixins import ResettableMixin from ._utils.check_dependencies import check_dependencies_are_installed from ._weighting_bases import Weighting check_dependencies_are_installed(["cvxpy", "ecos"]) +from typing import cast + import cvxpy as cp import numpy as np import torch @@ -20,7 +21,7 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class NashMTL(WeightedAggregator): +class NashMTL(WeightedAggregator, ResettableMixin): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of `Multi-Task Learning as a Bargaining Game `_. 
@@ -83,7 +84,7 @@ def __repr__(self) -> str: ) -class _NashMTLWeighting(Weighting[Matrix]): +class _NashMTLWeighting(Weighting[Matrix], ResettableMixin): """ :class:`~torchjd.aggregation.Weighting` that extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining Game From 516cff0e849eecf757b0bb5ca80f6a6763022ac8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 15 Apr 2026 12:25:19 +0200 Subject: [PATCH 2/8] Improve docstring of ResettableMixin --- src/torchjd/aggregation/_mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py index 42e09c53..597758cb 100644 --- a/src/torchjd/aggregation/_mixins.py +++ b/src/torchjd/aggregation/_mixins.py @@ -2,7 +2,7 @@ class ResettableMixin(ABC): - """Mixin that resettable classes should inherit from.""" + """Class implementing a reset method.""" @abstractmethod def reset(self) -> None: From 8a6be19117a2a8cb970367f529a5de5faa3ec5d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 15 Apr 2026 13:30:50 +0200 Subject: [PATCH 3/8] Make ResettableMixin public --- src/torchjd/aggregation/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 93f824e3..76a942d5 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -71,6 +71,7 @@ from ._krum import Krum, KrumWeighting from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting +from ._mixins import ResettableMixin from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting from ._sum import Sum, SumWeighting @@ -107,6 +108,7 @@ "PCGradWeighting", "Random", "RandomWeighting", + "ResettableMixin", "Sum", "SumWeighting", "TrimmedMean", From 1710c782117f18b1d43f100cc0b291a27ed1a0a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= 
Date: Wed, 15 Apr 2026 13:31:09 +0200 Subject: [PATCH 4/8] Add documentation entry for ResettableMixin in aggregation/index.rst --- docs/source/docs/aggregation/index.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 64ba6f63..d2d641bf 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -22,6 +22,10 @@ Abstract base classes :undoc-members: :exclude-members: forward +.. autoclass:: torchjd.aggregation.ResettableMixin + :members: + :undoc-members: + .. toctree:: :hidden: From 450c0d3f1f4e3006e9079d8cbd9be7334234822f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 15 Apr 2026 13:34:55 +0200 Subject: [PATCH 5/8] Rename ResettableMixin to Resettable --- docs/source/docs/aggregation/index.rst | 2 +- src/torchjd/aggregation/__init__.py | 4 ++-- src/torchjd/aggregation/_gradvac.py | 6 +++--- src/torchjd/aggregation/_mixins.py | 4 ++-- src/torchjd/aggregation/_nash_mtl.py | 6 +++--- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index d2d641bf..9d58d45d 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -22,7 +22,7 @@ Abstract base classes :undoc-members: :exclude-members: forward -.. autoclass:: torchjd.aggregation.ResettableMixin +.. 
autoclass:: torchjd.aggregation.Resettable :members: :undoc-members: diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 76a942d5..6edc164f 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -71,7 +71,7 @@ from ._krum import Krum, KrumWeighting from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting -from ._mixins import ResettableMixin +from ._mixins import Resettable from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting from ._sum import Sum, SumWeighting @@ -108,7 +108,7 @@ "PCGradWeighting", "Random", "RandomWeighting", - "ResettableMixin", + "Resettable", "Sum", "SumWeighting", "TrimmedMean", diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 4b877d4c..87eb7bf0 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -6,14 +6,14 @@ from torch import Tensor from torchjd._linalg import PSDMatrix -from torchjd.aggregation._mixins import ResettableMixin +from torchjd.aggregation._mixins import Resettable from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import Weighting -class GradVac(GramianWeightedAggregator, ResettableMixin): +class GradVac(GramianWeightedAggregator, Resettable): r""" :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task @@ -72,7 +72,7 @@ def __repr__(self) -> str: return f"GradVac(beta={self.beta!r}, eps={self.eps!r})" -class GradVacWeighting(Weighting[PSDMatrix], ResettableMixin): +class GradVacWeighting(Weighting[PSDMatrix], Resettable): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.GradVac`. 
diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py index 597758cb..ffb4b255 100644 --- a/src/torchjd/aggregation/_mixins.py +++ b/src/torchjd/aggregation/_mixins.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod -class ResettableMixin(ABC): - """Class implementing a reset method.""" +class Resettable(ABC): + """Mixin adding a reset method.""" @abstractmethod def reset(self) -> None: diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 737d7b26..a202b3c2 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -2,7 +2,7 @@ # See NOTICES for the full license text. from torchjd._linalg import Matrix -from torchjd.aggregation._mixins import ResettableMixin +from torchjd.aggregation._mixins import Resettable from ._utils.check_dependencies import check_dependencies_are_installed from ._weighting_bases import Weighting @@ -21,7 +21,7 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class NashMTL(WeightedAggregator, ResettableMixin): +class NashMTL(WeightedAggregator, Resettable): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of `Multi-Task Learning as a Bargaining Game `_. 
@@ -84,7 +84,7 @@ def __repr__(self) -> str: ) -class _NashMTLWeighting(Weighting[Matrix], ResettableMixin): +class _NashMTLWeighting(Weighting[Matrix], Resettable): """ :class:`~torchjd.aggregation.Weighting` that extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining Game From bd72cf3132f4c42b29d895d0a00e907ba8451370 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 15 Apr 2026 13:37:17 +0200 Subject: [PATCH 6/8] Add Resettable link in aggregator and weighting docstrings --- src/torchjd/aggregation/_gradvac.py | 2 ++ src/torchjd/aggregation/_nash_mtl.py | 7 ++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 87eb7bf0..5068928c 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -15,6 +15,7 @@ class GradVac(GramianWeightedAggregator, Resettable): r""" + :class:`~torchjd.aggregation._mixins.Resettable` :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) @@ -74,6 +75,7 @@ def __repr__(self) -> str: class GradVacWeighting(Weighting[PSDMatrix], Resettable): r""" + :class:`~torchjd.aggregation._mixins.Resettable` :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.GradVac`. 
diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index a202b3c2..de96b548 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -23,6 +23,7 @@ class NashMTL(WeightedAggregator, Resettable): """ + :class:`~torchjd.aggregation._mixins.Resettable` :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of `Multi-Task Learning as a Bargaining Game `_. @@ -86,9 +87,9 @@ def __repr__(self) -> str: class _NashMTLWeighting(Weighting[Matrix], Resettable): """ - :class:`~torchjd.aggregation.Weighting` that extracts weights using the step decision - of Algorithm 1 of `Multi-Task Learning as a Bargaining Game - `_. + :class:`~torchjd.aggregation._mixins.Resettable` :class:`~torchjd.aggregation.Weighting` that + extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining + Game `_. :param n_tasks: The number of tasks, corresponding to the number of rows in the provided matrices. From 6a6e014cbfc6d9f6a3182e848a988111aefca513 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 15 Apr 2026 13:38:22 +0200 Subject: [PATCH 7/8] Rename Resettable to Stateful --- docs/source/docs/aggregation/index.rst | 2 +- src/torchjd/aggregation/__init__.py | 4 ++-- src/torchjd/aggregation/_gradvac.py | 10 +++++----- src/torchjd/aggregation/_mixins.py | 2 +- src/torchjd/aggregation/_nash_mtl.py | 10 +++++----- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 9d58d45d..0dbdab0c 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -22,7 +22,7 @@ Abstract base classes :undoc-members: :exclude-members: forward -.. autoclass:: torchjd.aggregation.Resettable +.. 
autoclass:: torchjd.aggregation.Stateful :members: :undoc-members: diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 6edc164f..400cfe27 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -71,7 +71,7 @@ from ._krum import Krum, KrumWeighting from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting -from ._mixins import Resettable +from ._mixins import Stateful from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting from ._sum import Sum, SumWeighting @@ -108,7 +108,7 @@ "PCGradWeighting", "Random", "RandomWeighting", - "Resettable", + "Stateful", "Sum", "SumWeighting", "TrimmedMean", diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 5068928c..efb55f44 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -6,16 +6,16 @@ from torch import Tensor from torchjd._linalg import PSDMatrix -from torchjd.aggregation._mixins import Resettable +from torchjd.aggregation._mixins import Stateful from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import Weighting -class GradVac(GramianWeightedAggregator, Resettable): +class GradVac(GramianWeightedAggregator, Stateful): r""" - :class:`~torchjd.aggregation._mixins.Resettable` + :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) @@ -73,9 +73,9 @@ def __repr__(self) -> str: return f"GradVac(beta={self.beta!r}, eps={self.eps!r})" -class GradVacWeighting(Weighting[PSDMatrix], Resettable): +class GradVacWeighting(Weighting[PSDMatrix], Stateful): r""" - 
:class:`~torchjd.aggregation._mixins.Resettable` + :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.GradVac`. diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py index ffb4b255..8481feab 100644 --- a/src/torchjd/aggregation/_mixins.py +++ b/src/torchjd/aggregation/_mixins.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod -class Resettable(ABC): +class Stateful(ABC): """Mixin adding a reset method.""" @abstractmethod diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index de96b548..f64f5182 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -2,7 +2,7 @@ # See NOTICES for the full license text. from torchjd._linalg import Matrix -from torchjd.aggregation._mixins import Resettable +from torchjd.aggregation._mixins import Stateful from ._utils.check_dependencies import check_dependencies_are_installed from ._weighting_bases import Weighting @@ -21,9 +21,9 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class NashMTL(WeightedAggregator, Resettable): +class NashMTL(WeightedAggregator, Stateful): """ - :class:`~torchjd.aggregation._mixins.Resettable` + :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of `Multi-Task Learning as a Bargaining Game `_. @@ -85,9 +85,9 @@ def __repr__(self) -> str: ) -class _NashMTLWeighting(Weighting[Matrix], Resettable): +class _NashMTLWeighting(Weighting[Matrix], Stateful): """ - :class:`~torchjd.aggregation._mixins.Resettable` :class:`~torchjd.aggregation.Weighting` that + :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` that extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining Game `_. 
From fe8e4621f0d70772a6483afbd88233d7d4337e4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 15 Apr 2026 13:46:35 +0200 Subject: [PATCH 8/8] Add information about the stateful mixin in contributing --- CONTRIBUTING.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3847963f..9e55d66a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -210,7 +210,8 @@ implementation of a mathematical aggregator. > [!NOTE] > We also accept stateful aggregators, whose output depends both on the Jacobian and on some -> internal state (which can be affected for example by previous Jacobians). +> internal state (which can be affected for example by previous Jacobians). Such aggregators should +> inherit from the `Stateful` mixin and implement a `reset` method. > [!NOTE] > Some aggregators may depend on something else than the Jacobian. To implement them, please add