diff --git a/CHANGELOG.md b/CHANGELOG.md index ab5a536d..926e641b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a softmax-projected gradient step on a cross-batch matrix `G = J_1 @ J_2.T`, computed from two independent mini-batches using `autojac.jac`. - Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf), diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 13e405cb..66d74570 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -41,6 +41,7 @@ Abstract base classes krum.rst mean.rst mgda.rst + modo.rst nash_mtl.rst pcgrad.rst random.rst diff --git a/docs/source/docs/aggregation/modo.rst b/docs/source/docs/aggregation/modo.rst new file mode 100644 index 00000000..98b8d515 --- /dev/null +++ b/docs/source/docs/aggregation/modo.rst @@ -0,0 +1,7 @@ +:hide-toc: + +MoDo +==== + +.. autoclass:: torchjd.aggregation.MoDoWeighting + :members: __call__, reset diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 1814d320..92bbadec 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -53,6 +53,7 @@ from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting from ._mixins import Stateful +from ._modo import MoDoWeighting from ._nash_mtl import NashMTL from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting @@ -87,6 +88,7 @@ "MeanWeighting", "MGDA", "MGDAWeighting", + "MoDoWeighting", "NashMTL", "PCGrad", "PCGradWeighting", diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py new file mode 100644 index 00000000..f38e2ad9 --- /dev/null +++ b/src/torchjd/aggregation/_modo.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from typing import cast + +import torch +from torch import Tensor + +from torchjd.aggregation._mixins import Stateful, _NonDifferentiable +from torchjd.linalg import Matrix + +from ._weighting_bases import _MatrixWeighting + + +class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): + r""" + :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Three-Way + Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance + `_ (JMLR 2024). + + .. warning:: + The input matrix must be :math:`G = J_1 J_2^\top`, computed from two **independent** + mini-batches via :func:`torchjd.autojac.jac`. Using a single-batch Gramian + (:math:`J_1 J_1^\top`) breaks the convergence guarantee. See the usage examples below. + + :param gamma: Learning rate of the task-weight update. Must be positive. + :param rho: Non-negative :math:`\ell_2` regularisation coefficient. + + .. admonition:: Example (two batches per step) + + The following example reproduces basic MoDo using two independent mini-batches per step. + + .. code-block:: python + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import MoDoWeighting + from torchjd.autojac import jac + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = MoDoWeighting(gamma=0.1, rho=0.0) + params = list(model.parameters()) + + # loader_1 and loader_2 must yield independent draws of the same size. + for batch_1, batch_2 in zip(loader_1, loader_2): + input_1, target_1 = batch_1 + input_2, target_2 = batch_2 + + losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + jacs_1 = jac(losses_1, params) + J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) + + # retain_graph=True keeps the graph for the backward step below. + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + jacs_2 = jac(losses_2, params, retain_graph=True) + J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) + + G = J_1 @ J_2.T + weights = weighting(G) + + losses_2.backward(weights) + optimizer.step() + optimizer.zero_grad() + + .. admonition:: Example (three batches per step) + + The following example reproduces basic MoDo using three independent mini-batches per step, + keeping the :math:`\lambda` update and the parameter update on separate draws. + + .. code-block:: python + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import MoDoWeighting + from torchjd.autojac import jac + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = MoDoWeighting(gamma=0.1, rho=0.0) + params = list(model.parameters()) + + for batch_1, batch_2, batch_3 in zip(loader_1, loader_2, loader_3): + input_1, target_1 = batch_1 + input_2, target_2 = batch_2 + input_3, target_3 = batch_3 + + losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + jacs_1 = jac(losses_1, params) + J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) + + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + jacs_2 = jac(losses_2, params) + J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) + + G = J_1 @ J_2.T + weights = weighting(G) + + losses_3 = criterion(model(input_3).squeeze(dim=1), target_3) + losses_3.backward(weights) + optimizer.step() + optimizer.zero_grad() + """ + + def __init__(self, gamma: float = 0.1, rho: float = 0.0) -> None: + super().__init__() + self.gamma = gamma + self.rho = rho + self._lambda: Tensor | None = None + self._state_key: tuple[int, torch.dtype, torch.device] | None = None + + @property + def gamma(self) -> float: + return self._gamma + + @gamma.setter + def gamma(self, value: float) -> None: + if value <= 0.0: + raise ValueError(f"Attribute `gamma` must be positive. Found gamma={value!r}.") + self._gamma = value + + @property + def rho(self) -> float: + return self._rho + + @rho.setter + def rho(self, value: float) -> None: + if value < 0.0: + raise ValueError(f"Attribute `rho` must be non-negative. Found rho={value!r}.") + self._rho = value + + def reset(self) -> None: + """Clears the stored task weights so the next forward starts from uniform.""" + + self._lambda = None + self._state_key = None + + def forward(self, matrix: Matrix, /) -> Tensor: + self._ensure_state(matrix) + lambd = cast(Tensor, self._lambda) + + grad = matrix @ lambd + self._rho * lambd + lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) + + self._lambda = lambd + return lambd + + def _ensure_state(self, matrix: Matrix) -> None: + key = (matrix.shape[0], matrix.dtype, matrix.device) + if self._state_key == key and self._lambda is not None: + return + self._lambda = matrix.new_full((matrix.shape[0],), 1.0 / matrix.shape[0]) + self._state_key = key + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(gamma={self.gamma!r}, rho={self.rho!r})" diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py new file mode 100644 index 00000000..6203c99d --- /dev/null +++ b/tests/unit/aggregation/test_modo.py @@ -0,0 +1,159 @@ +import torch +from pytest import raises +from torch.testing import assert_close +from utils.tensors import randn_, tensor_ + +from torchjd.aggregation._modo import MoDoWeighting + + +def test_representations() -> None: + W = MoDoWeighting(gamma=0.1, rho=0.05) + assert repr(W) == "MoDoWeighting(gamma=0.1, rho=0.05)" + + +def test_reset_restores_first_step_behavior() -> None: + J = randn_((3, 8)) + G = J @ J.T + W = MoDoWeighting(gamma=0.1) + first = W(G) + W(G) + W.reset() + assert_close(first, W(G)) + + +def test_gamma_setter_accepts_valid() -> None: + W = MoDoWeighting() + W.gamma = 0.01 + assert W.gamma == 0.01 + W.gamma = 0.1 + assert W.gamma == 0.1 + W.gamma = 1.0 + assert W.gamma == 1.0 + + +def test_gamma_setter_rejects_non_positive() -> None: + W = MoDoWeighting() + with raises(ValueError, match="gamma"): + W.gamma = 0.0 + with raises(ValueError, match="gamma"): + W.gamma = -0.1 + + +def test_rho_setter_accepts_valid() -> None: + W = MoDoWeighting() + W.rho = 0.0 + assert W.rho == 0.0 + W.rho = 0.1 + assert W.rho == 0.1 + + +def test_rho_setter_rejects_negative() -> None: + W = MoDoWeighting() + with raises(ValueError, match="rho"): + W.rho = -0.1 + + +def test_output_lies_on_simplex() -> None: + """The softmax projection ensures the weights sum to 1 and are non-negative.""" + + J = randn_((4, 10)) + G = J @ J.T + W = MoDoWeighting(gamma=0.1, rho=0.05) + weights = W(G) + assert weights.shape == (4,) + assert (weights >= 0).all() + assert_close(weights.sum(), tensor_(1.0)) + + +def test_small_gamma_stays_near_uniform() -> None: + """With a tiny gamma, one step barely moves lambda from the uniform initialisation.""" + + J = randn_((3, 8)) + G = J @ J.T + m = J.shape[0] + W = MoDoWeighting(gamma=1e-8) + uniform = tensor_([1.0 / m] * m) + assert_close(W(G), uniform, atol=1e-6, rtol=1e-6) + + +def test_update_recurrence() -> None: + """Verify one step of the softmax-projected gradient update by hand.""" + + gamma = 0.1 + rho = 0.05 + J = randn_((3, 8)) + G = J @ J.T + m = J.shape[0] + + W = MoDoWeighting(gamma=gamma, rho=rho) + lambda_0 = tensor_([1.0 / m] * m) + grad = G @ lambda_0 + rho * lambda_0 + expected = torch.softmax(lambda_0 - gamma * grad, dim=-1) + + assert_close(W(G), expected) + + +def test_two_consecutive_steps() -> None: + """Verify two consecutive steps of the softmax-projected gradient update.""" + + gamma = 0.1 + rho = 0.0 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G1 = J1 @ J1.T + G2 = J2 @ J2.T + m = J1.shape[0] + + W = MoDoWeighting(gamma=gamma, rho=rho) + + lambda_0 = tensor_([1.0 / m] * m) + grad_1 = G1 @ lambda_0 + rho * lambda_0 + lambda_1 = torch.softmax(lambda_0 - gamma * grad_1, dim=-1) + + grad_2 = G2 @ lambda_1 + rho * lambda_1 + lambda_2 = torch.softmax(lambda_1 - gamma * grad_2, dim=-1) + + assert_close(W(G1), lambda_1) + assert_close(W(G2), lambda_2) + + +def test_changing_m_auto_resets() -> None: + """When the number of objectives changes, the state is re-initialised to uniform.""" + + W = MoDoWeighting(gamma=0.1) + W(randn_((3, 8)) @ randn_((3, 8)).T) + # After a state-resetting call with m=2, the first output should equal the uniform step's output. + fresh = MoDoWeighting(gamma=0.1) + J = randn_((2, 8)) + G = J @ J.T + assert_close(W(G), fresh(G)) + + +def test_non_differentiable() -> None: + """The _NonDifferentiable mixin must prevent autograd graph construction.""" + + G = randn_((3, 8)) @ randn_((3, 8)).T + G.requires_grad_(True) + W = MoDoWeighting() + weights = W(G) + assert not weights.requires_grad + + +def test_non_symmetric_input() -> None: + """MoDoWeighting must accept and correctly process a non-symmetric cross-batch matrix.""" + + gamma = 0.1 + rho = 0.05 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G = J1 @ J2.T # not symmetric, not PSD in general + m = J1.shape[0] + + W = MoDoWeighting(gamma=gamma, rho=rho) + lambda_0 = tensor_([1.0 / m] * m) + grad = G @ lambda_0 + rho * lambda_0 + expected = torch.softmax(lambda_0 - gamma * grad, dim=-1) + + assert_close(W(G), expected) + assert W(G).shape == (m,) + assert (W(G) >= 0).all()