From b981d098ef3828e94880f8e10233921a810e45f5 Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 28 May 2026 23:47:35 -0400 Subject: [PATCH 1/3] feat(aggregation): Add MoDoWeighting --- CHANGELOG.md | 6 + docs/source/docs/aggregation/index.rst | 1 + docs/source/docs/aggregation/modo.rst | 7 ++ src/torchjd/aggregation/__init__.py | 2 + src/torchjd/aggregation/_modo.py | 143 +++++++++++++++++++++++ tests/unit/aggregation/test_modo.py | 153 +++++++++++++++++++++++++ 6 files changed, 312 insertions(+) create mode 100644 docs/source/docs/aggregation/modo.rst create mode 100644 src/torchjd/aggregation/_modo.py create mode 100644 tests/unit/aggregation/test_modo.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9204433a..7ca370c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Added + +- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization,Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a + softmax-projected gradient step on the Gramian, intended to be composed with `autogram.Engine` + in a two-batch training loop. + ## [0.12.0] - 2026-05-28 ### Added diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 13e405cb..66d74570 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -41,6 +41,7 @@ Abstract base classes krum.rst mean.rst mgda.rst + modo.rst nash_mtl.rst pcgrad.rst random.rst diff --git a/docs/source/docs/aggregation/modo.rst b/docs/source/docs/aggregation/modo.rst new file mode 100644 index 00000000..98b8d515 --- /dev/null +++ b/docs/source/docs/aggregation/modo.rst @@ -0,0 +1,7 @@ +:hide-toc: + +MoDo +==== + +.. 
autoclass:: torchjd.aggregation.MoDoWeighting + :members: __call__, reset diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 1814d320..92bbadec 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -53,6 +53,7 @@ from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting from ._mixins import Stateful +from ._modo import MoDoWeighting from ._nash_mtl import NashMTL from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting @@ -87,6 +88,7 @@ "MeanWeighting", "MGDA", "MGDAWeighting", + "MoDoWeighting", "NashMTL", "PCGrad", "PCGradWeighting", diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py new file mode 100644 index 00000000..d24629b3 --- /dev/null +++ b/src/torchjd/aggregation/_modo.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from typing import cast + +import torch +from torch import Tensor + +from torchjd.aggregation._mixins import Stateful +from torchjd.linalg import PSDMatrix + +from ._weighting_bases import _GramianWeighting + + +class MoDoWeighting(_GramianWeighting, Stateful): + r""" + :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] implementing the + task-weight update from `Three-Way Trade-Off in Multi-Objective Learning: Optimization, + Generalization and Conflict-Avoidance `_ + (JMLR 2024), commonly referred to as MoDo (Multi-Objective gradient with Double sampling). + + At each call, the weights :math:`\lambda` are updated by a projected gradient step on + :math:`\lambda^\top G \lambda + \rho \|\lambda\|^2` where :math:`G = G_1 G_1^\top` is the + Gramian of the first mini-batch's Jacobian: + + .. 
math:: + + \lambda_{t+1} = \operatorname{softmax}\!\bigl( + \lambda_t - \gamma \cdot (G \lambda_t + \rho \lambda_t) + \bigr) + + The paper specifies hard simplex projection :math:`\Pi_\Delta`; we follow the `official + LibMTL implementation `_ and use + :func:`torch.softmax` as the projection step. + + The state :math:`\lambda_{t-1}` is initialised lazily to the uniform vector + :math:`[1/m, \ldots, 1/m]` on the first forward call once :math:`m` is known, and is reset + automatically when :math:`m`, ``dtype`` or ``device`` of the input Gramian changes. Use + :meth:`reset` to manually restart the smoothing from uniform weights. + + .. warning:: + MoDo's convergence guarantees rely on **double sampling**: the Gramian passed to this + weighting must come from a mini-batch that is independent of the one used for the + subsequent parameter update. See the usage example below. + + :param gamma: Learning rate of the task-weight update. Must be positive. + :param rho: Non-negative :math:`\ell_2` regularisation coefficient. + + .. admonition:: Example + + Train a model using MoDo with two independent mini-batches per step. The first batch + drives the :math:`\lambda` update via the Gramian; the second batch drives the parameter + update via the usual backward pass. + + .. code-block:: python + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import MoDoWeighting + from torchjd.autogram import Engine + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = MoDoWeighting(gamma=0.1, rho=0.0) + engine = Engine(model, batch_dim=0) + + # loader_1 and loader_2 must yield independent draws from the same distribution. + for batch_1, batch_2 in zip(loader_1, loader_2): + input_1, target_1 = batch_1 + input_2, target_2 = batch_2 + + # Step 1: Gramian from batch 1 drives the lambda update. 
+ losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + gramian = engine.compute_gramian(losses_1) + weights = weighting(gramian) + + # Step 2: backward on batch 2 with those weights drives the parameter update. + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + losses_2.backward(weights) + optimizer.step() + optimizer.zero_grad() + """ + + def __init__(self, gamma: float = 0.1, rho: float = 0.0) -> None: + super().__init__() + self.gamma = gamma + self.rho = rho + self._lambda: Tensor | None = None + self._state_key: tuple[int, torch.dtype, torch.device] | None = None + + @property + def gamma(self) -> float: + return self._gamma + + @gamma.setter + def gamma(self, value: float) -> None: + if value <= 0.0: + raise ValueError(f"Attribute `gamma` must be positive. Found gamma={value!r}.") + self._gamma = value + + @property + def rho(self) -> float: + return self._rho + + @rho.setter + def rho(self, value: float) -> None: + if value < 0.0: + raise ValueError(f"Attribute `rho` must be non-negative. 
Found rho={value!r}.") + self._rho = value + + def reset(self) -> None: + """Clears the stored task weights so the next forward starts from uniform.""" + + self._lambda = None + self._state_key = None + + def forward(self, gramian: PSDMatrix, /) -> Tensor: + m = gramian.shape[0] + if m == 0: + return gramian.new_empty((0,)) + + self._ensure_state(gramian) + lambd = cast(Tensor, self._lambda) + + with torch.no_grad(): + grad = gramian @ lambd + self._rho * lambd + lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) + + self._lambda = lambd + return lambd + + def _ensure_state(self, gramian: PSDMatrix) -> None: + key = (gramian.shape[0], gramian.dtype, gramian.device) + if self._state_key == key and self._lambda is not None: + return + self._lambda = gramian.new_full((gramian.shape[0],), 1.0 / gramian.shape[0]) + self._state_key = key + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(gamma={self.gamma!r}, rho={self.rho!r})" diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py new file mode 100644 index 00000000..af80fbab --- /dev/null +++ b/tests/unit/aggregation/test_modo.py @@ -0,0 +1,153 @@ +import torch +from pytest import mark, raises +from torch import Tensor +from torch.testing import assert_close +from utils.tensors import randn_, tensor_ + +from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator +from torchjd.aggregation._modo import MoDoWeighting + +from ._asserts import assert_expected_structure +from ._inputs import scaled_matrices, typical_matrices + +gramian_pairs = [ + (GramianWeightedAggregator(MoDoWeighting()), m) for m in typical_matrices + scaled_matrices +] + + +def test_representations() -> None: + W = MoDoWeighting(gamma=0.1, rho=0.05) + assert repr(W) == "MoDoWeighting(gamma=0.1, rho=0.05)" + + +@mark.parametrize(["aggregator", "matrix"], gramian_pairs) +def test_expected_structure_gramian_weighting( + aggregator: GramianWeightedAggregator, matrix: Tensor +) -> None: 
+ assert_expected_structure(aggregator, matrix) + + +def test_reset_restores_first_step_behavior() -> None: + J = randn_((3, 8)) + G = J @ J.T + W = MoDoWeighting(gamma=0.1) + first = W(G) + W(G) + W.reset() + assert_close(first, W(G)) + + +def test_gamma_setter_accepts_valid() -> None: + W = MoDoWeighting() + W.gamma = 0.01 + assert W.gamma == 0.01 + W.gamma = 0.1 + assert W.gamma == 0.1 + W.gamma = 1.0 + assert W.gamma == 1.0 + + +def test_gamma_setter_rejects_non_positive() -> None: + W = MoDoWeighting() + with raises(ValueError, match="gamma"): + W.gamma = 0.0 + with raises(ValueError, match="gamma"): + W.gamma = -0.1 + + +def test_rho_setter_accepts_valid() -> None: + W = MoDoWeighting() + W.rho = 0.0 + assert W.rho == 0.0 + W.rho = 0.1 + assert W.rho == 0.1 + + +def test_rho_setter_rejects_negative() -> None: + W = MoDoWeighting() + with raises(ValueError, match="rho"): + W.rho = -0.1 + + +def test_output_lies_on_simplex() -> None: + """The softmax projection ensures the weights sum to 1 and are non-negative.""" + + J = randn_((4, 10)) + G = J @ J.T + W = MoDoWeighting(gamma=0.1, rho=0.05) + weights = W(G) + assert weights.shape == (4,) + assert (weights >= 0).all() + assert_close(weights.sum(), tensor_(1.0)) + + +def test_small_gamma_stays_near_uniform() -> None: + """With a tiny gamma, one step barely moves lambda from the uniform initialisation.""" + + J = randn_((3, 8)) + G = J @ J.T + m = J.shape[0] + W = MoDoWeighting(gamma=1e-8) + uniform = tensor_([1.0 / m] * m) + assert_close(W(G), uniform, atol=1e-6, rtol=1e-6) + + +def test_update_recurrence() -> None: + """Verify one step of the softmax-projected gradient update by hand.""" + + gamma = 0.1 + rho = 0.05 + J = randn_((3, 8)) + G = J @ J.T + m = J.shape[0] + + W = MoDoWeighting(gamma=gamma, rho=rho) + lambda_0 = tensor_([1.0 / m] * m) + grad = G @ lambda_0 + rho * lambda_0 + expected = torch.softmax(lambda_0 - gamma * grad, dim=-1) + + assert_close(W(G), expected) + + +def 
test_two_consecutive_steps() -> None: + """Verify two consecutive steps of the softmax-projected gradient update.""" + + gamma = 0.1 + rho = 0.0 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G1 = J1 @ J1.T + G2 = J2 @ J2.T + m = J1.shape[0] + + W = MoDoWeighting(gamma=gamma, rho=rho) + + lambda_0 = tensor_([1.0 / m] * m) + grad_1 = G1 @ lambda_0 + rho * lambda_0 + lambda_1 = torch.softmax(lambda_0 - gamma * grad_1, dim=-1) + + grad_2 = G2 @ lambda_1 + rho * lambda_1 + lambda_2 = torch.softmax(lambda_1 - gamma * grad_2, dim=-1) + + assert_close(W(G1), lambda_1) + assert_close(W(G2), lambda_2) + + +def test_changing_m_auto_resets() -> None: + """When the number of objectives changes, the state is re-initialised to uniform.""" + + W = MoDoWeighting(gamma=0.1) + W(randn_((3, 8)) @ randn_((3, 8)).T) + # After a state-resetting call with m=2, the first output should equal the uniform step's output. + fresh = MoDoWeighting(gamma=0.1) + J = randn_((2, 8)) + G = J @ J.T + assert_close(W(G), fresh(G)) + + +def test_zero_rows() -> None: + """A (0, 0) Gramian yields an empty weight vector.""" + + W = MoDoWeighting() + weights = W(tensor_([]).reshape(0, 0)) + assert weights.shape == (0,) From b416fbadffeb90f33761c32faa0598a4309e5731 Mon Sep 17 00:00:00 2001 From: Khush Date: Fri, 29 May 2026 10:29:45 -0400 Subject: [PATCH 2/3] refactor(aggregation): Address review feedback on MoDoWeighting --- src/torchjd/aggregation/_modo.py | 31 ++++++++++++----------------- tests/unit/aggregation/test_modo.py | 8 -------- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py index d24629b3..d219b44e 100644 --- a/src/torchjd/aggregation/_modo.py +++ b/src/torchjd/aggregation/_modo.py @@ -5,23 +5,22 @@ import torch from torch import Tensor -from torchjd.aggregation._mixins import Stateful +from torchjd.aggregation._mixins import Stateful, _NonDifferentiable from torchjd.linalg import PSDMatrix from ._weighting_bases 
import _GramianWeighting -class MoDoWeighting(_GramianWeighting, Stateful): +class MoDoWeighting(_GramianWeighting, Stateful, _NonDifferentiable): r""" :class:`~torchjd.aggregation._mixins.Stateful` - :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] implementing the - task-weight update from `Three-Way Trade-Off in Multi-Objective Learning: Optimization, - Generalization and Conflict-Avoidance `_ - (JMLR 2024), commonly referred to as MoDo (Multi-Objective gradient with Double sampling). + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] from `Three-Way + Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance + `_ (JMLR 2024), commonly referred + to as MoDo (Multi-Objective gradient with Double sampling). - At each call, the weights :math:`\lambda` are updated by a projected gradient step on - :math:`\lambda^\top G \lambda + \rho \|\lambda\|^2` where :math:`G = G_1 G_1^\top` is the - Gramian of the first mini-batch's Jacobian: + Given a Gramian :math:`G`, the weights :math:`\lambda` are updated at each call by a + softmax-projected gradient step: .. math:: @@ -36,12 +35,13 @@ class MoDoWeighting(_GramianWeighting, Stateful): The state :math:`\lambda_{t-1}` is initialised lazily to the uniform vector :math:`[1/m, \ldots, 1/m]` on the first forward call once :math:`m` is known, and is reset automatically when :math:`m`, ``dtype`` or ``device`` of the input Gramian changes. Use - :meth:`reset` to manually restart the smoothing from uniform weights. + :meth:`reset` to manually restart from uniform weights. .. warning:: MoDo's convergence guarantees rely on **double sampling**: the Gramian passed to this weighting must come from a mini-batch that is independent of the one used for the - subsequent parameter update. See the usage example below. + subsequent parameter update. 
The Gramian can be computed efficiently from a batch of + losses using the :class:`~torchjd.autogram.Engine`. See the usage example below. :param gamma: Learning rate of the task-weight update. Must be positive. :param rho: Non-negative :math:`\ell_2` regularisation coefficient. @@ -118,16 +118,11 @@ def reset(self) -> None: self._state_key = None def forward(self, gramian: PSDMatrix, /) -> Tensor: - m = gramian.shape[0] - if m == 0: - return gramian.new_empty((0,)) - self._ensure_state(gramian) lambd = cast(Tensor, self._lambda) - with torch.no_grad(): - grad = gramian @ lambd + self._rho * lambd - lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) + grad = gramian @ lambd + self._rho * lambd + lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) self._lambda = lambd return lambd diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py index af80fbab..9b9193be 100644 --- a/tests/unit/aggregation/test_modo.py +++ b/tests/unit/aggregation/test_modo.py @@ -143,11 +143,3 @@ def test_changing_m_auto_resets() -> None: J = randn_((2, 8)) G = J @ J.T assert_close(W(G), fresh(G)) - - -def test_zero_rows() -> None: - """A (0, 0) Gramian yields an empty weight vector.""" - - W = MoDoWeighting() - weights = W(tensor_([]).reshape(0, 0)) - assert weights.shape == (0,) From 2c6188a8dadd28d17cb988bcf19afe7805165232 Mon Sep 17 00:00:00 2001 From: Khush Date: Sun, 31 May 2026 18:13:08 -0400 Subject: [PATCH 3/3] refactor(aggregation): Address review feedback on MoDoWeighting --- CHANGELOG.md | 4 +- src/torchjd/aggregation/_modo.py | 114 +++++++++++++++++----------- tests/unit/aggregation/test_modo.py | 48 +++++++----- 3 files changed, 101 insertions(+), 65 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 225e65d5..926e641b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,9 +10,7 @@ changelog does not include internal changes that do not affect the user. 
### Added -- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization,Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a - softmax-projected gradient step on the Gramian, intended to be composed with `autogram.Engine` - in a two-batch training loop. +- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a softmax-projected gradient step on a cross-batch matrix `G = J_1 @ J_2.T`, computed from two independent mini-batches using `autojac.jac`. - Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf), diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py index d219b44e..f38e2ad9 100644 --- a/src/torchjd/aggregation/_modo.py +++ b/src/torchjd/aggregation/_modo.py @@ -6,51 +6,29 @@ from torch import Tensor from torchjd.aggregation._mixins import Stateful, _NonDifferentiable -from torchjd.linalg import PSDMatrix +from torchjd.linalg import Matrix -from ._weighting_bases import _GramianWeighting +from ._weighting_bases import _MatrixWeighting -class MoDoWeighting(_GramianWeighting, Stateful, _NonDifferentiable): +class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): r""" :class:`~torchjd.aggregation._mixins.Stateful` - :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] from `Three-Way + :class:`~torchjd.aggregation.Weighting` 
[:class:`~torchjd.linalg.Matrix`] from `Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance - `_ (JMLR 2024), commonly referred - to as MoDo (Multi-Objective gradient with Double sampling). - - Given a Gramian :math:`G`, the weights :math:`\lambda` are updated at each call by a - softmax-projected gradient step: - - .. math:: - - \lambda_{t+1} = \operatorname{softmax}\!\bigl( - \lambda_t - \gamma \cdot (G \lambda_t + \rho \lambda_t) - \bigr) - - The paper specifies hard simplex projection :math:`\Pi_\Delta`; we follow the `official - LibMTL implementation `_ and use - :func:`torch.softmax` as the projection step. - - The state :math:`\lambda_{t-1}` is initialised lazily to the uniform vector - :math:`[1/m, \ldots, 1/m]` on the first forward call once :math:`m` is known, and is reset - automatically when :math:`m`, ``dtype`` or ``device`` of the input Gramian changes. Use - :meth:`reset` to manually restart from uniform weights. + `_ (JMLR 2024). .. warning:: - MoDo's convergence guarantees rely on **double sampling**: the Gramian passed to this - weighting must come from a mini-batch that is independent of the one used for the - subsequent parameter update. The Gramian can be computed efficiently from a batch of - losses using the :class:`~torchjd.autogram.Engine`. See the usage example below. + The input matrix must be :math:`G = J_1 J_2^\top`, computed from two **independent** + mini-batches via :func:`torchjd.autojac.jac`. Using a single-batch Gramian + (:math:`J_1 J_1^\top`) breaks the convergence guarantee. See the usage examples below. :param gamma: Learning rate of the task-weight update. Must be positive. :param rho: Non-negative :math:`\ell_2` regularisation coefficient. - .. admonition:: Example + .. admonition:: Example (two batches per step) - Train a model using MoDo with two independent mini-batches per step. 
The first batch - drives the :math:`\lambda` update via the Gramian; the second batch drives the parameter - update via the usual backward pass. + The following example reproduces basic MoDo using two independent mini-batches per step. .. code-block:: python @@ -59,29 +37,75 @@ class MoDoWeighting(_GramianWeighting, Stateful, _NonDifferentiable): from torch.optim import SGD from torchjd.aggregation import MoDoWeighting - from torchjd.autogram import Engine + from torchjd.autojac import jac model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) optimizer = SGD(model.parameters()) criterion = MSELoss(reduction="none") weighting = MoDoWeighting(gamma=0.1, rho=0.0) - engine = Engine(model, batch_dim=0) + params = list(model.parameters()) - # loader_1 and loader_2 must yield independent draws from the same distribution. + # loader_1 and loader_2 must yield independent draws of the same size. for batch_1, batch_2 in zip(loader_1, loader_2): input_1, target_1 = batch_1 input_2, target_2 = batch_2 - # Step 1: Gramian from batch 1 drives the lambda update. losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) - gramian = engine.compute_gramian(losses_1) - weights = weighting(gramian) + jacs_1 = jac(losses_1, params) + J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) - # Step 2: backward on batch 2 with those weights drives the parameter update. + # retain_graph=True keeps the graph for the backward step below. losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + jacs_2 = jac(losses_2, params, retain_graph=True) + J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) + + G = J_1 @ J_2.T + weights = weighting(G) + losses_2.backward(weights) optimizer.step() optimizer.zero_grad() + + .. admonition:: Example (three batches per step) + + The following example reproduces basic MoDo using three independent mini-batches per step, + keeping the :math:`\lambda` update and the parameter update on separate draws. + + .. 
code-block:: python + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import MoDoWeighting + from torchjd.autojac import jac + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = MoDoWeighting(gamma=0.1, rho=0.0) + params = list(model.parameters()) + + for batch_1, batch_2, batch_3 in zip(loader_1, loader_2, loader_3): + input_1, target_1 = batch_1 + input_2, target_2 = batch_2 + input_3, target_3 = batch_3 + + losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + jacs_1 = jac(losses_1, params) + J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) + + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + jacs_2 = jac(losses_2, params) + J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) + + G = J_1 @ J_2.T + weights = weighting(G) + + losses_3 = criterion(model(input_3).squeeze(dim=1), target_3) + losses_3.backward(weights) + optimizer.step() + optimizer.zero_grad() """ def __init__(self, gamma: float = 0.1, rho: float = 0.0) -> None: @@ -117,21 +141,21 @@ def reset(self) -> None: self._lambda = None self._state_key = None - def forward(self, gramian: PSDMatrix, /) -> Tensor: - self._ensure_state(gramian) + def forward(self, matrix: Matrix, /) -> Tensor: + self._ensure_state(matrix) lambd = cast(Tensor, self._lambda) - grad = gramian @ lambd + self._rho * lambd + grad = matrix @ lambd + self._rho * lambd lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) self._lambda = lambd return lambd - def _ensure_state(self, gramian: PSDMatrix) -> None: - key = (gramian.shape[0], gramian.dtype, gramian.device) + def _ensure_state(self, matrix: Matrix) -> None: + key = (matrix.shape[0], matrix.dtype, matrix.device) if self._state_key == key and self._lambda is not None: return - self._lambda = gramian.new_full((gramian.shape[0],), 1.0 / gramian.shape[0]) + 
self._lambda = matrix.new_full((matrix.shape[0],), 1.0 / matrix.shape[0]) self._state_key = key def __repr__(self) -> str: diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py index 9b9193be..6203c99d 100644 --- a/tests/unit/aggregation/test_modo.py +++ b/tests/unit/aggregation/test_modo.py @@ -1,32 +1,16 @@ import torch -from pytest import mark, raises -from torch import Tensor +from pytest import raises from torch.testing import assert_close from utils.tensors import randn_, tensor_ -from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator from torchjd.aggregation._modo import MoDoWeighting -from ._asserts import assert_expected_structure -from ._inputs import scaled_matrices, typical_matrices - -gramian_pairs = [ - (GramianWeightedAggregator(MoDoWeighting()), m) for m in typical_matrices + scaled_matrices -] - def test_representations() -> None: W = MoDoWeighting(gamma=0.1, rho=0.05) assert repr(W) == "MoDoWeighting(gamma=0.1, rho=0.05)" -@mark.parametrize(["aggregator", "matrix"], gramian_pairs) -def test_expected_structure_gramian_weighting( - aggregator: GramianWeightedAggregator, matrix: Tensor -) -> None: - assert_expected_structure(aggregator, matrix) - - def test_reset_restores_first_step_behavior() -> None: J = randn_((3, 8)) G = J @ J.T @@ -143,3 +127,33 @@ def test_changing_m_auto_resets() -> None: J = randn_((2, 8)) G = J @ J.T assert_close(W(G), fresh(G)) + + +def test_non_differentiable() -> None: + """The _NonDifferentiable mixin must prevent autograd graph construction.""" + + G = randn_((3, 8)) @ randn_((3, 8)).T + G.requires_grad_(True) + W = MoDoWeighting() + weights = W(G) + assert not weights.requires_grad + + +def test_non_symmetric_input() -> None: + """MoDoWeighting must accept and correctly process a non-symmetric cross-batch matrix.""" + + gamma = 0.1 + rho = 0.05 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G = J1 @ J2.T # not symmetric, not PSD in general + m = J1.shape[0] + + W = 
MoDoWeighting(gamma=gamma, rho=rho) + lambda_0 = tensor_([1.0 / m] * m) + grad = G @ lambda_0 + rho * lambda_0 + expected = torch.softmax(lambda_0 - gamma * grad, dim=-1) + weights = W(G) + assert_close(weights, expected) + assert weights.shape == (m,) + assert (weights >= 0).all()