Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ implementation of a mathematical aggregator.

> [!NOTE]
> We also accept stateful aggregators, whose output depends both on the Jacobian and on some
> internal state (which can be affected for example by previous Jacobians).
> internal state (which can be affected, for example, by previous Jacobians). Such aggregators should
> inherit from the `Stateful` mixin and implement a `reset` method.

> [!NOTE]
> Some aggregators may depend on something other than the Jacobian. To implement them, please add
Expand Down
4 changes: 4 additions & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ Abstract base classes
:undoc-members:
:exclude-members: forward

.. autoclass:: torchjd.aggregation.Stateful
:members:
:undoc-members:


.. toctree::
:hidden:
Expand Down
2 changes: 2 additions & 0 deletions src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from ._krum import Krum, KrumWeighting
from ._mean import Mean, MeanWeighting
from ._mgda import MGDA, MGDAWeighting
from ._mixins import Stateful
from ._pcgrad import PCGrad, PCGradWeighting
from ._random import Random, RandomWeighting
from ._sum import Sum, SumWeighting
Expand Down Expand Up @@ -107,6 +108,7 @@
"PCGradWeighting",
"Random",
"RandomWeighting",
"Stateful",
"Sum",
"SumWeighting",
"TrimmedMean",
Expand Down
7 changes: 5 additions & 2 deletions src/torchjd/aggregation/_gradvac.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
from torch import Tensor

from torchjd._linalg import PSDMatrix
from torchjd.aggregation._mixins import Stateful

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import Weighting


class GradVac(GramianWeightedAggregator):
class GradVac(GramianWeightedAggregator, Stateful):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
Expand Down Expand Up @@ -71,8 +73,9 @@ def __repr__(self) -> str:
return f"GradVac(beta={self.beta!r}, eps={self.eps!r})"


class GradVacWeighting(Weighting[PSDMatrix]):
class GradVacWeighting(Weighting[PSDMatrix], Stateful):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GradVac`.

Expand Down
9 changes: 9 additions & 0 deletions src/torchjd/aggregation/_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod


class Stateful(ABC):
    """Abstract mixin for aggregators whose output depends on some internal state.

    Subclasses must implement :meth:`reset`, which restores the internal state to
    its initial value.
    """

    # Empty __slots__ so that this mixin never forces a __dict__ onto slotted
    # subclasses (standard practice for mixin classes).
    __slots__ = ()

    @abstractmethod
    def reset(self) -> None:
        """Resets the internal state."""
16 changes: 9 additions & 7 deletions src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon.
# See NOTICES for the full license text.

from typing import cast

from torchjd._linalg import Matrix
from torchjd.aggregation._mixins import Stateful

from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import Weighting

check_dependencies_are_installed(["cvxpy", "ecos"])

from typing import cast

import cvxpy as cp
import numpy as np
import torch
Expand All @@ -20,8 +21,9 @@
from ._utils.non_differentiable import raise_non_differentiable_error


class NashMTL(WeightedAggregator):
class NashMTL(WeightedAggregator, Stateful):
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of
`Multi-Task Learning as a Bargaining Game <https://arxiv.org/pdf/2202.01017.pdf>`_.

Expand Down Expand Up @@ -83,11 +85,11 @@ def __repr__(self) -> str:
)


class _NashMTLWeighting(Weighting[Matrix]):
class _NashMTLWeighting(Weighting[Matrix], Stateful):
"""
:class:`~torchjd.aggregation.Weighting` that extracts weights using the step decision
of Algorithm 1 of `Multi-Task Learning as a Bargaining Game
<https://arxiv.org/pdf/2202.01017.pdf>`_.
:class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` that
extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining
Game <https://arxiv.org/pdf/2202.01017.pdf>`_.

:param n_tasks: The number of tasks, corresponding to the number of rows in the provided
matrices.
Expand Down