Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a softmax-projected gradient step on a cross-batch matrix `G = J_1 @ J_2.T`, computed from two independent mini-batches using `autojac.jac`.
- Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature
Aggregation and Geometric Loss Strategy for Multi-Task
Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf),
Expand Down
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Abstract base classes
krum.rst
mean.rst
mgda.rst
modo.rst
nash_mtl.rst
pcgrad.rst
random.rst
Expand Down
7 changes: 7 additions & 0 deletions docs/source/docs/aggregation/modo.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

MoDo
====

.. autoclass:: torchjd.aggregation.MoDoWeighting
:members: __call__, reset
2 changes: 2 additions & 0 deletions src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from ._mean import Mean, MeanWeighting
from ._mgda import MGDA, MGDAWeighting
from ._mixins import Stateful
from ._modo import MoDoWeighting
from ._nash_mtl import NashMTL
from ._pcgrad import PCGrad, PCGradWeighting
from ._random import Random, RandomWeighting
Expand Down Expand Up @@ -87,6 +88,7 @@
"MeanWeighting",
"MGDA",
"MGDAWeighting",
"MoDoWeighting",
"NashMTL",
"PCGrad",
"PCGradWeighting",
Expand Down
162 changes: 162 additions & 0 deletions src/torchjd/aggregation/_modo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from __future__ import annotations

from typing import cast

import torch
from torch import Tensor

from torchjd.aggregation._mixins import Stateful, _NonDifferentiable
from torchjd.linalg import Matrix

from ._weighting_bases import _MatrixWeighting


class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable):
    r"""
    :class:`~torchjd.aggregation._mixins.Stateful`
    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Three-Way
    Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance
    <https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf>`_ (JMLR 2024).

    Each call performs one update of the stored task weights :math:`\lambda` from the input matrix
    :math:`G`, and returns the updated weights:

    .. math::

        \lambda \leftarrow \operatorname{softmax}\big(\lambda - \gamma (G \lambda + \rho \lambda)\big)

    The stored weights are initialised to the uniform vector :math:`(1/m, \dots, 1/m)`, where
    :math:`m` is the number of rows of the input matrix. They are re-initialised to uniform
    whenever the number of rows, the dtype or the device of the input matrix differs from the one
    seen at the previous call, and after a call to :meth:`reset`.

    .. warning::
        The input matrix must be :math:`G = J_1 J_2^\top`, computed from two **independent**
        mini-batches via :func:`torchjd.autojac.jac`. Using a single-batch Gramian
        (:math:`J_1 J_1^\top`) breaks the convergence guarantee. See the usage examples below.

    :param gamma: Learning rate of the task-weight update. Must be positive.
    :param rho: Non-negative :math:`\ell_2` regularisation coefficient.

    .. admonition:: Example (two batches per step)

        The following example reproduces basic MoDo using two independent mini-batches per step.

        .. code-block:: python

            import torch
            from torch.nn import Linear, MSELoss, ReLU, Sequential
            from torch.optim import SGD

            from torchjd.aggregation import MoDoWeighting
            from torchjd.autojac import jac

            model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
            optimizer = SGD(model.parameters())
            criterion = MSELoss(reduction="none")
            weighting = MoDoWeighting(gamma=0.1, rho=0.0)
            params = list(model.parameters())

            # loader_1 and loader_2 must yield independent draws of the same size.
            for batch_1, batch_2 in zip(loader_1, loader_2):
                input_1, target_1 = batch_1
                input_2, target_2 = batch_2

                losses_1 = criterion(model(input_1).squeeze(dim=1), target_1)
                jacs_1 = jac(losses_1, params)
                J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1)

                # retain_graph=True keeps the graph for the backward step below.
                losses_2 = criterion(model(input_2).squeeze(dim=1), target_2)
                jacs_2 = jac(losses_2, params, retain_graph=True)
                J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1)

                G = J_1 @ J_2.T
                weights = weighting(G)

                losses_2.backward(weights)
                optimizer.step()
                optimizer.zero_grad()

    .. admonition:: Example (three batches per step)

        The following example reproduces basic MoDo using three independent mini-batches per step,
        keeping the :math:`\lambda` update and the parameter update on separate draws.

        .. code-block:: python

            import torch
            from torch.nn import Linear, MSELoss, ReLU, Sequential
            from torch.optim import SGD

            from torchjd.aggregation import MoDoWeighting
            from torchjd.autojac import jac

            model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
            optimizer = SGD(model.parameters())
            criterion = MSELoss(reduction="none")
            weighting = MoDoWeighting(gamma=0.1, rho=0.0)
            params = list(model.parameters())

            for batch_1, batch_2, batch_3 in zip(loader_1, loader_2, loader_3):
                input_1, target_1 = batch_1
                input_2, target_2 = batch_2
                input_3, target_3 = batch_3

                losses_1 = criterion(model(input_1).squeeze(dim=1), target_1)
                jacs_1 = jac(losses_1, params)
                J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1)

                losses_2 = criterion(model(input_2).squeeze(dim=1), target_2)
                jacs_2 = jac(losses_2, params)
                J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1)

                G = J_1 @ J_2.T
                weights = weighting(G)

                losses_3 = criterion(model(input_3).squeeze(dim=1), target_3)
                losses_3.backward(weights)
                optimizer.step()
                optimizer.zero_grad()
    """
Comment on lines +33 to +109
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to change if we go for my suggestion.


def __init__(self, gamma: float = 0.1, rho: float = 0.0) -> None:
    """
    Store the hyper-parameters and start without any stored task weights.

    :param gamma: Learning rate of the task-weight update. Must be positive.
    :param rho: Regularisation coefficient. Must be non-negative.
    :raises ValueError: If ``gamma`` or ``rho`` is rejected by its property setter.
    """

    super().__init__()

    # Nothing is stored yet: both fields stay ``None`` until they are populated elsewhere.
    self._state_key: tuple[int, torch.dtype, torch.device] | None = None
    self._lambda: Tensor | None = None

    # Assigned through the properties so that the setters validate the values.
    self.gamma = gamma
    self.rho = rho

@property
def gamma(self) -> float:
    """Learning rate of the task-weight update."""

    return self._gamma

@gamma.setter
def gamma(self, value: float) -> None:
    """
    Validate and store the learning rate.

    :param value: New learning rate. Must be strictly positive.
    :raises ValueError: If ``value`` is not strictly positive (this includes NaN).
    """

    # ``not value > 0.0`` rather than ``value <= 0.0``: every ordering comparison involving NaN
    # is false, so the latter would silently accept a NaN learning rate.
    if not value > 0.0:
        raise ValueError(f"Attribute `gamma` must be positive. Found gamma={value!r}.")
    self._gamma = value

@property
def rho(self) -> float:
    """Regularisation coefficient of the task-weight update."""

    return self._rho

@rho.setter
def rho(self, value: float) -> None:
    """
    Validate and store the regularisation coefficient.

    :param value: New coefficient. Must be non-negative.
    :raises ValueError: If ``value`` is negative or NaN.
    """

    # ``not value >= 0.0`` rather than ``value < 0.0``: every ordering comparison involving NaN
    # is false, so the latter would silently accept a NaN coefficient.
    if not value >= 0.0:
        raise ValueError(f"Attribute `rho` must be non-negative. Found rho={value!r}.")
    self._rho = value

def reset(self) -> None:
    """Forgets the stored task weights and the key describing the matrix they were built for."""

    # Both fields are cleared together so that they can never disagree.
    self._lambda = self._state_key = None

def forward(self, matrix: Matrix, /) -> Tensor:
    """
    Perform one update of the stored task weights from ``matrix`` and return the new weights.

    :param matrix: Matrix multiplied with the stored weights to build the update direction.
    :returns: The updated weights, which are also kept as the new state.
    """

    # Make sure stored weights compatible with ``matrix`` exist before reading them.
    self._ensure_state(matrix)
    previous = cast(Tensor, self._lambda)

    # Update direction: the matrix-vector product plus the ``rho``-scaled current weights.
    direction = matrix @ previous + self._rho * previous
    updated = (previous - self._gamma * direction).softmax(dim=-1)

    self._lambda = updated
    return updated

def _ensure_state(self, matrix: Matrix) -> None:
    """
    Make the stored weights match ``matrix``, creating uniform ones when they do not.

    The stored weights are kept only if some exist and they were created for the same number of
    rows, dtype and device as ``matrix``. Otherwise they are replaced by a uniform vector with one
    entry per row of ``matrix``.
    """

    n_rows = matrix.shape[0]
    signature = (n_rows, matrix.dtype, matrix.device)

    needs_init = self._lambda is None or self._state_key != signature
    if needs_init:
        # ``new_full`` creates the vector with the dtype and device of ``matrix``.
        self._lambda = matrix.new_full((n_rows,), 1.0 / n_rows)
        self._state_key = signature

def __repr__(self) -> str:
    """Return a constructor-like description showing the current ``gamma`` and ``rho``."""

    shown = ", ".join(f"{name}={getattr(self, name)!r}" for name in ("gamma", "rho"))
    return f"{self.__class__.__name__}({shown})"
159 changes: 159 additions & 0 deletions tests/unit/aggregation/test_modo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import torch
from pytest import raises
from torch.testing import assert_close
from utils.tensors import randn_, tensor_

from torchjd.aggregation._modo import MoDoWeighting


def test_representations() -> None:
    """``repr`` shows the class name together with the configured ``gamma`` and ``rho``."""

    weighting = MoDoWeighting(gamma=0.1, rho=0.05)
    expected = "MoDoWeighting(gamma=0.1, rho=0.05)"
    assert repr(weighting) == expected


def test_reset_restores_first_step_behavior() -> None:
    """After ``reset``, a call reproduces the output of the very first call."""

    jacobian = randn_((3, 8))
    gramian = jacobian @ jacobian.T
    weighting = MoDoWeighting(gamma=0.1)

    first_output = weighting(gramian)
    weighting(gramian)  # Move the state away from its initial value before resetting.
    weighting.reset()
    output_after_reset = weighting(gramian)

    assert_close(first_output, output_after_reset)


def test_gamma_setter_accepts_valid() -> None:
    """Positive values assigned to ``gamma`` are stored and read back unchanged."""

    weighting = MoDoWeighting()
    for value in (0.01, 0.1, 1.0):
        weighting.gamma = value
        assert weighting.gamma == value


def test_gamma_setter_rejects_non_positive() -> None:
    """Zero and negative values assigned to ``gamma`` raise a ``ValueError`` naming it."""

    weighting = MoDoWeighting()
    for invalid in (0.0, -0.1):
        with raises(ValueError, match="gamma"):
            weighting.gamma = invalid


def test_rho_setter_accepts_valid() -> None:
    """Zero and positive values assigned to ``rho`` are stored and read back unchanged."""

    weighting = MoDoWeighting()
    for value in (0.0, 0.1):
        weighting.rho = value
        assert weighting.rho == value


def test_rho_setter_rejects_negative() -> None:
    """A negative value assigned to ``rho`` raises a ``ValueError`` naming it."""

    weighting = MoDoWeighting()
    negative = -0.1
    with raises(ValueError, match="rho"):
        weighting.rho = negative


def test_output_lies_on_simplex() -> None:
    """The returned weights have one entry per row, are non-negative and sum to one."""

    jacobian = randn_((4, 10))
    gramian = jacobian @ jacobian.T
    weighting = MoDoWeighting(gamma=0.1, rho=0.05)

    weights = weighting(gramian)

    assert weights.shape == (4,)
    assert weights.ge(0).all()
    total = weights.sum()
    assert_close(total, tensor_(1.0))


def test_small_gamma_stays_near_uniform() -> None:
    """A single step with a tiny ``gamma`` leaves the weights very close to uniform."""

    jacobian = randn_((3, 8))
    gramian = jacobian @ jacobian.T
    n_objectives = jacobian.shape[0]
    expected = tensor_([1.0 / n_objectives for _ in range(n_objectives)])

    weighting = MoDoWeighting(gamma=1e-8)
    weights = weighting(gramian)

    assert_close(weights, expected, atol=1e-6, rtol=1e-6)


def test_update_recurrence() -> None:
    """The first call must match one hand-computed step starting from uniform weights."""

    gamma, rho = 0.1, 0.05
    jacobian = randn_((3, 8))
    gramian = jacobian @ jacobian.T
    n_objectives = jacobian.shape[0]

    weighting = MoDoWeighting(gamma=gamma, rho=rho)

    # Reference computation of a single step.
    uniform = tensor_([1.0 / n_objectives] * n_objectives)
    gradient = gramian @ uniform + rho * uniform
    expected = torch.softmax(uniform - gamma * gradient, dim=-1)

    actual = weighting(gramian)
    assert_close(actual, expected)


def test_two_consecutive_steps() -> None:
    """Two successive calls must follow the update rule, the second starting from the first."""

    gamma, rho = 0.1, 0.0
    first_jacobian = randn_((3, 8))
    second_jacobian = randn_((3, 8))
    gramians = [first_jacobian @ first_jacobian.T, second_jacobian @ second_jacobian.T]
    n_objectives = first_jacobian.shape[0]

    weighting = MoDoWeighting(gamma=gamma, rho=rho)

    # Reference recurrence, starting from the uniform weights.
    expected_steps = []
    current = tensor_([1.0 / n_objectives] * n_objectives)
    for gramian in gramians:
        gradient = gramian @ current + rho * current
        current = torch.softmax(current - gamma * gradient, dim=-1)
        expected_steps.append(current)

    for gramian, expected in zip(gramians, expected_steps):
        assert_close(weighting(gramian), expected)


def test_changing_m_auto_resets() -> None:
    """Feeding a matrix with a different number of rows restarts from uniform weights."""

    used = MoDoWeighting(gamma=0.1)
    left = randn_((3, 8))
    right = randn_((3, 8))
    used(left @ right.T)  # Gives ``used`` a state built for three objectives.

    # A brand-new instance is the reference for "starting from uniform" with two objectives.
    reference = MoDoWeighting(gamma=0.1)
    jacobian = randn_((2, 8))
    gramian = jacobian @ jacobian.T

    actual = used(gramian)
    expected = reference(gramian)
    assert_close(actual, expected)


def test_non_differentiable() -> None:
    """Weights computed from a matrix that requires grad must not require grad themselves."""

    left = randn_((3, 8))
    right = randn_((3, 8))
    matrix = left @ right.T
    matrix.requires_grad_(True)

    weighting = MoDoWeighting()
    result = weighting(matrix)

    assert not result.requires_grad


def test_non_symmetric_input() -> None:
    """MoDoWeighting must accept and correctly process a non-symmetric cross-batch matrix."""

    gamma = 0.1
    rho = 0.05
    J1 = randn_((3, 8))
    J2 = randn_((3, 8))
    G = J1 @ J2.T  # not symmetric, not PSD in general
    m = J1.shape[0]

    W = MoDoWeighting(gamma=gamma, rho=rho)
    lambda_0 = tensor_([1.0 / m] * m)
    grad = G @ lambda_0 + rho * lambda_0
    expected = torch.softmax(lambda_0 - gamma * grad, dim=-1)

    # The weighting is stateful: every call advances the stored weights. Call it exactly once so
    # that all assertions below inspect the same first-step output.
    weights = W(G)

    assert_close(weights, expected)
    assert weights.shape == (m,)
    assert (weights >= 0).all()
Loading